db.rs

   1use crate::{Error, Result};
   2use anyhow::{anyhow, Context};
   3use async_trait::async_trait;
   4use axum::http::StatusCode;
   5use collections::HashMap;
   6use futures::StreamExt;
   7use serde::{Deserialize, Serialize};
   8pub use sqlx::postgres::PgPoolOptions as DbOptions;
   9use sqlx::{
  10    migrate::{Migrate as _, Migration, MigrationSource},
  11    types::Uuid,
  12    FromRow, QueryBuilder,
  13};
  14use std::{cmp, ops::Range, path::Path, time::Duration};
  15use time::{OffsetDateTime, PrimitiveDateTime};
  16
  17#[async_trait]
  18pub trait Db: Send + Sync {
  19    async fn create_user(
  20        &self,
  21        email_address: &str,
  22        admin: bool,
  23        params: NewUserParams,
  24    ) -> Result<NewUserResult>;
  25    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
  26    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
  27    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
  28    async fn get_user_metrics_id(&self, id: UserId) -> Result<String>;
  29    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
  30    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
  31    async fn get_user_by_github_account(
  32        &self,
  33        github_login: &str,
  34        github_user_id: Option<i32>,
  35    ) -> Result<Option<User>>;
  36    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
  37    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
  38    async fn destroy_user(&self, id: UserId) -> Result<()>;
  39
  40    async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()>;
  41    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
  42    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
  43    async fn create_invite_from_code(
  44        &self,
  45        code: &str,
  46        email_address: &str,
  47        device_id: Option<&str>,
  48    ) -> Result<Invite>;
  49
  50    async fn create_signup(&self, signup: Signup) -> Result<()>;
  51    async fn get_waitlist_summary(&self) -> Result<WaitlistSummary>;
  52    async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>>;
  53    async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()>;
  54    async fn create_user_from_invite(
  55        &self,
  56        invite: &Invite,
  57        user: NewUserParams,
  58    ) -> Result<Option<NewUserResult>>;
  59
  60    /// Registers a new project for the given user.
  61    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
  62
  63    /// Unregisters a project for the given project id.
  64    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
  65
  66    /// Update file counts by extension for the given project and worktree.
  67    async fn update_worktree_extensions(
  68        &self,
  69        project_id: ProjectId,
  70        worktree_id: u64,
  71        extensions: HashMap<String, u32>,
  72    ) -> Result<()>;
  73
  74    /// Get the file counts on the given project keyed by their worktree and extension.
  75    async fn get_project_extensions(
  76        &self,
  77        project_id: ProjectId,
  78    ) -> Result<HashMap<u64, HashMap<String, usize>>>;
  79
  80    /// Record which users have been active in which projects during
  81    /// a given period of time.
  82    async fn record_user_activity(
  83        &self,
  84        time_period: Range<OffsetDateTime>,
  85        active_projects: &[(UserId, ProjectId)],
  86    ) -> Result<()>;
  87
  88    /// Get the number of users who have been active in the given
  89    /// time period for at least the given time duration.
  90    async fn get_active_user_count(
  91        &self,
  92        time_period: Range<OffsetDateTime>,
  93        min_duration: Duration,
  94        only_collaborative: bool,
  95    ) -> Result<usize>;
  96
  97    /// Get the users that have been most active during the given time period,
  98    /// along with the amount of time they have been active in each project.
  99    async fn get_top_users_activity_summary(
 100        &self,
 101        time_period: Range<OffsetDateTime>,
 102        max_user_count: usize,
 103    ) -> Result<Vec<UserActivitySummary>>;
 104
 105    /// Get the project activity for the given user and time period.
 106    async fn get_user_activity_timeline(
 107        &self,
 108        time_period: Range<OffsetDateTime>,
 109        user_id: UserId,
 110    ) -> Result<Vec<UserActivityPeriod>>;
 111
 112    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
 113    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
 114    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
 115    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
 116    async fn dismiss_contact_notification(
 117        &self,
 118        responder_id: UserId,
 119        requester_id: UserId,
 120    ) -> Result<()>;
 121    async fn respond_to_contact_request(
 122        &self,
 123        responder_id: UserId,
 124        requester_id: UserId,
 125        accept: bool,
 126    ) -> Result<()>;
 127
 128    async fn create_access_token_hash(
 129        &self,
 130        user_id: UserId,
 131        access_token_hash: &str,
 132        max_access_token_count: usize,
 133    ) -> Result<()>;
 134    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
 135
 136    #[cfg(any(test, feature = "seed-support"))]
 137    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
 138    #[cfg(any(test, feature = "seed-support"))]
 139    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
 140    #[cfg(any(test, feature = "seed-support"))]
 141    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
 142    #[cfg(any(test, feature = "seed-support"))]
 143    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
 144    #[cfg(any(test, feature = "seed-support"))]
 145
 146    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
 147    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
 148    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
 149        -> Result<bool>;
 150
 151    #[cfg(any(test, feature = "seed-support"))]
 152    async fn add_channel_member(
 153        &self,
 154        channel_id: ChannelId,
 155        user_id: UserId,
 156        is_admin: bool,
 157    ) -> Result<()>;
 158    async fn create_channel_message(
 159        &self,
 160        channel_id: ChannelId,
 161        sender_id: UserId,
 162        body: &str,
 163        timestamp: OffsetDateTime,
 164        nonce: u128,
 165    ) -> Result<MessageId>;
 166    async fn get_channel_messages(
 167        &self,
 168        channel_id: ChannelId,
 169        count: usize,
 170        before_id: Option<MessageId>,
 171    ) -> Result<Vec<ChannelMessage>>;
 172
 173    #[cfg(test)]
 174    async fn teardown(&self, url: &str);
 175
 176    #[cfg(test)]
 177    fn as_fake(&self) -> Option<&FakeDb>;
 178}
 179
 180#[cfg(any(test, debug_assertions))]
 181pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> =
 182    Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
 183
 184#[cfg(not(any(test, debug_assertions)))]
 185pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> = None;
 186
 187pub struct PostgresDb {
 188    pool: sqlx::PgPool,
 189}
 190
 191impl PostgresDb {
 192    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 193        let pool = DbOptions::new()
 194            .max_connections(max_connections)
 195            .connect(url)
 196            .await
 197            .context("failed to connect to postgres database")?;
 198        Ok(Self { pool })
 199    }
 200
 201    pub async fn migrate(
 202        &self,
 203        migrations_path: &Path,
 204        ignore_checksum_mismatch: bool,
 205    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 206        let migrations = MigrationSource::resolve(migrations_path)
 207            .await
 208            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 209
 210        let mut conn = self.pool.acquire().await?;
 211
 212        conn.ensure_migrations_table().await?;
 213        let applied_migrations: HashMap<_, _> = conn
 214            .list_applied_migrations()
 215            .await?
 216            .into_iter()
 217            .map(|m| (m.version, m))
 218            .collect();
 219
 220        let mut new_migrations = Vec::new();
 221        for migration in migrations {
 222            match applied_migrations.get(&migration.version) {
 223                Some(applied_migration) => {
 224                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 225                    {
 226                        Err(anyhow!(
 227                            "checksum mismatch for applied migration {}",
 228                            migration.description
 229                        ))?;
 230                    }
 231                }
 232                None => {
 233                    let elapsed = conn.apply(&migration).await?;
 234                    new_migrations.push((migration, elapsed));
 235                }
 236            }
 237        }
 238
 239        Ok(new_migrations)
 240    }
 241
 242    pub fn fuzzy_like_string(string: &str) -> String {
 243        let mut result = String::with_capacity(string.len() * 2 + 1);
 244        for c in string.chars() {
 245            if c.is_alphanumeric() {
 246                result.push('%');
 247                result.push(c);
 248            }
 249        }
 250        result.push('%');
 251        result
 252    }
 253}
 254
 255#[async_trait]
 256impl Db for PostgresDb {
 257    // users
 258
 259    async fn create_user(
 260        &self,
 261        email_address: &str,
 262        admin: bool,
 263        params: NewUserParams,
 264    ) -> Result<NewUserResult> {
 265        let query = "
 266            INSERT INTO users (email_address, github_login, github_user_id, admin)
 267            VALUES ($1, $2, $3, $4)
 268            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 269            RETURNING id, metrics_id::text
 270        ";
 271        let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 272            .bind(email_address)
 273            .bind(params.github_login)
 274            .bind(params.github_user_id)
 275            .bind(admin)
 276            .fetch_one(&self.pool)
 277            .await?;
 278        Ok(NewUserResult {
 279            user_id,
 280            metrics_id,
 281            signup_device_id: None,
 282            inviting_user_id: None,
 283        })
 284    }
 285
 286    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 287        let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 288        Ok(sqlx::query_as(query)
 289            .bind(limit as i32)
 290            .bind((page * limit) as i32)
 291            .fetch_all(&self.pool)
 292            .await?)
 293    }
 294
 295    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 296        let like_string = Self::fuzzy_like_string(name_query);
 297        let query = "
 298            SELECT users.*
 299            FROM users
 300            WHERE github_login ILIKE $1
 301            ORDER BY github_login <-> $2
 302            LIMIT $3
 303        ";
 304        Ok(sqlx::query_as(query)
 305            .bind(like_string)
 306            .bind(name_query)
 307            .bind(limit as i32)
 308            .fetch_all(&self.pool)
 309            .await?)
 310    }
 311
 312    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 313        let users = self.get_users_by_ids(vec![id]).await?;
 314        Ok(users.into_iter().next())
 315    }
 316
 317    async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 318        let query = "
 319            SELECT metrics_id::text
 320            FROM users
 321            WHERE id = $1
 322        ";
 323        Ok(sqlx::query_scalar(query)
 324            .bind(id)
 325            .fetch_one(&self.pool)
 326            .await?)
 327    }
 328
 329    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 330        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 331        let query = "
 332            SELECT users.*
 333            FROM users
 334            WHERE users.id = ANY ($1)
 335        ";
 336        Ok(sqlx::query_as(query)
 337            .bind(&ids)
 338            .fetch_all(&self.pool)
 339            .await?)
 340    }
 341
 342    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
 343        let query = format!(
 344            "
 345            SELECT users.*
 346            FROM users
 347            WHERE invite_count = 0
 348            AND inviter_id IS{} NULL
 349            ",
 350            if invited_by_another_user { " NOT" } else { "" }
 351        );
 352
 353        Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 354    }
 355
 356    async fn get_user_by_github_account(
 357        &self,
 358        github_login: &str,
 359        github_user_id: Option<i32>,
 360    ) -> Result<Option<User>> {
 361        if let Some(github_user_id) = github_user_id {
 362            let mut user = sqlx::query_as::<_, User>(
 363                "
 364                UPDATE users
 365                SET github_login = $1
 366                WHERE github_user_id = $2
 367                RETURNING *
 368                ",
 369            )
 370            .bind(github_login)
 371            .bind(github_user_id)
 372            .fetch_optional(&self.pool)
 373            .await?;
 374
 375            if user.is_none() {
 376                user = sqlx::query_as::<_, User>(
 377                    "
 378                    UPDATE users
 379                    SET github_user_id = $1
 380                    WHERE github_login = $2
 381                    RETURNING *
 382                    ",
 383                )
 384                .bind(github_user_id)
 385                .bind(github_login)
 386                .fetch_optional(&self.pool)
 387                .await?;
 388            }
 389
 390            Ok(user)
 391        } else {
 392            Ok(sqlx::query_as(
 393                "
 394                SELECT * FROM users
 395                WHERE github_login = $1
 396                LIMIT 1
 397                ",
 398            )
 399            .bind(github_login)
 400            .fetch_optional(&self.pool)
 401            .await?)
 402        }
 403    }
 404
 405    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 406        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 407        Ok(sqlx::query(query)
 408            .bind(is_admin)
 409            .bind(id.0)
 410            .execute(&self.pool)
 411            .await
 412            .map(drop)?)
 413    }
 414
 415    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 416        let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 417        Ok(sqlx::query(query)
 418            .bind(connected_once)
 419            .bind(id.0)
 420            .execute(&self.pool)
 421            .await
 422            .map(drop)?)
 423    }
 424
 425    async fn destroy_user(&self, id: UserId) -> Result<()> {
 426        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 427        sqlx::query(query)
 428            .bind(id.0)
 429            .execute(&self.pool)
 430            .await
 431            .map(drop)?;
 432        let query = "DELETE FROM users WHERE id = $1;";
 433        Ok(sqlx::query(query)
 434            .bind(id.0)
 435            .execute(&self.pool)
 436            .await
 437            .map(drop)?)
 438    }
 439
 440    // signups
 441
 442    async fn create_signup(&self, signup: Signup) -> Result<()> {
 443        sqlx::query(
 444            "
 445            INSERT INTO signups
 446            (
 447                email_address,
 448                email_confirmation_code,
 449                email_confirmation_sent,
 450                platform_linux,
 451                platform_mac,
 452                platform_windows,
 453                platform_unknown,
 454                editor_features,
 455                programming_languages,
 456                device_id
 457            )
 458            VALUES
 459                ($1, $2, 'f', $3, $4, $5, 'f', $6, $7, $8)
 460            RETURNING id
 461            ",
 462        )
 463        .bind(&signup.email_address)
 464        .bind(&random_email_confirmation_code())
 465        .bind(&signup.platform_linux)
 466        .bind(&signup.platform_mac)
 467        .bind(&signup.platform_windows)
 468        .bind(&signup.editor_features)
 469        .bind(&signup.programming_languages)
 470        .bind(&signup.device_id)
 471        .execute(&self.pool)
 472        .await?;
 473        Ok(())
 474    }
 475
 476    async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
 477        Ok(sqlx::query_as(
 478            "
 479            SELECT
 480                COUNT(*) as count,
 481                COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
 482                COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
 483                COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
 484                COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
 485            FROM (
 486                SELECT *
 487                FROM signups
 488                WHERE
 489                    NOT email_confirmation_sent
 490            ) AS unsent
 491            ",
 492        )
 493        .fetch_one(&self.pool)
 494        .await?)
 495    }
 496
 497    async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
 498        Ok(sqlx::query_as(
 499            "
 500            SELECT
 501                email_address, email_confirmation_code
 502            FROM signups
 503            WHERE
 504                NOT email_confirmation_sent AND
 505                (platform_mac OR platform_unknown)
 506            LIMIT $1
 507            ",
 508        )
 509        .bind(count as i32)
 510        .fetch_all(&self.pool)
 511        .await?)
 512    }
 513
 514    async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
 515        sqlx::query(
 516            "
 517            UPDATE signups
 518            SET email_confirmation_sent = 't'
 519            WHERE email_address = ANY ($1)
 520            ",
 521        )
 522        .bind(
 523            &invites
 524                .iter()
 525                .map(|s| s.email_address.as_str())
 526                .collect::<Vec<_>>(),
 527        )
 528        .execute(&self.pool)
 529        .await?;
 530        Ok(())
 531    }
 532
 533    async fn create_user_from_invite(
 534        &self,
 535        invite: &Invite,
 536        user: NewUserParams,
 537    ) -> Result<Option<NewUserResult>> {
 538        let mut tx = self.pool.begin().await?;
 539
 540        let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
 541            i32,
 542            Option<UserId>,
 543            Option<UserId>,
 544            Option<String>,
 545        ) = sqlx::query_as(
 546            "
 547            SELECT id, user_id, inviting_user_id, device_id
 548            FROM signups
 549            WHERE
 550                email_address = $1 AND
 551                email_confirmation_code = $2
 552            ",
 553        )
 554        .bind(&invite.email_address)
 555        .bind(&invite.email_confirmation_code)
 556        .fetch_optional(&mut tx)
 557        .await?
 558        .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
 559
 560        if existing_user_id.is_some() {
 561            return Ok(None);
 562        }
 563
 564        let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
 565            "
 566            INSERT INTO users
 567            (email_address, github_login, github_user_id, admin, invite_count, invite_code)
 568            VALUES
 569            ($1, $2, $3, 'f', $4, $5)
 570            ON CONFLICT (github_login) DO UPDATE SET
 571                email_address = excluded.email_address,
 572                github_user_id = excluded.github_user_id,
 573                admin = excluded.admin
 574            RETURNING id, metrics_id::text
 575            ",
 576        )
 577        .bind(&invite.email_address)
 578        .bind(&user.github_login)
 579        .bind(&user.github_user_id)
 580        .bind(&user.invite_count)
 581        .bind(random_invite_code())
 582        .fetch_one(&mut tx)
 583        .await?;
 584
 585        sqlx::query(
 586            "
 587            UPDATE signups
 588            SET user_id = $1
 589            WHERE id = $2
 590            ",
 591        )
 592        .bind(&user_id)
 593        .bind(&signup_id)
 594        .execute(&mut tx)
 595        .await?;
 596
 597        if let Some(inviting_user_id) = inviting_user_id {
 598            let id: Option<UserId> = sqlx::query_scalar(
 599                "
 600                UPDATE users
 601                SET invite_count = invite_count - 1
 602                WHERE id = $1 AND invite_count > 0
 603                RETURNING id
 604                ",
 605            )
 606            .bind(&inviting_user_id)
 607            .fetch_optional(&mut tx)
 608            .await?;
 609
 610            if id.is_none() {
 611                Err(Error::Http(
 612                    StatusCode::UNAUTHORIZED,
 613                    "no invites remaining".to_string(),
 614                ))?;
 615            }
 616
 617            sqlx::query(
 618                "
 619                INSERT INTO contacts
 620                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 621                VALUES
 622                    ($1, $2, 't', 't', 't')
 623                ON CONFLICT DO NOTHING
 624                ",
 625            )
 626            .bind(inviting_user_id)
 627            .bind(user_id)
 628            .execute(&mut tx)
 629            .await?;
 630        }
 631
 632        tx.commit().await?;
 633        Ok(Some(NewUserResult {
 634            user_id,
 635            metrics_id,
 636            inviting_user_id,
 637            signup_device_id,
 638        }))
 639    }
 640
 641    // invite codes
 642
 643    async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
 644        let mut tx = self.pool.begin().await?;
 645        if count > 0 {
 646            sqlx::query(
 647                "
 648                UPDATE users
 649                SET invite_code = $1
 650                WHERE id = $2 AND invite_code IS NULL
 651            ",
 652            )
 653            .bind(random_invite_code())
 654            .bind(id)
 655            .execute(&mut tx)
 656            .await?;
 657        }
 658
 659        sqlx::query(
 660            "
 661            UPDATE users
 662            SET invite_count = $1
 663            WHERE id = $2
 664            ",
 665        )
 666        .bind(count as i32)
 667        .bind(id)
 668        .execute(&mut tx)
 669        .await?;
 670        tx.commit().await?;
 671        Ok(())
 672    }
 673
 674    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 675        let result: Option<(String, i32)> = sqlx::query_as(
 676            "
 677                SELECT invite_code, invite_count
 678                FROM users
 679                WHERE id = $1 AND invite_code IS NOT NULL 
 680            ",
 681        )
 682        .bind(id)
 683        .fetch_optional(&self.pool)
 684        .await?;
 685        if let Some((code, count)) = result {
 686            Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 687        } else {
 688            Ok(None)
 689        }
 690    }
 691
 692    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 693        sqlx::query_as(
 694            "
 695                SELECT *
 696                FROM users
 697                WHERE invite_code = $1
 698            ",
 699        )
 700        .bind(code)
 701        .fetch_optional(&self.pool)
 702        .await?
 703        .ok_or_else(|| {
 704            Error::Http(
 705                StatusCode::NOT_FOUND,
 706                "that invite code does not exist".to_string(),
 707            )
 708        })
 709    }
 710
 711    async fn create_invite_from_code(
 712        &self,
 713        code: &str,
 714        email_address: &str,
 715        device_id: Option<&str>,
 716    ) -> Result<Invite> {
 717        let mut tx = self.pool.begin().await?;
 718
 719        let existing_user: Option<UserId> = sqlx::query_scalar(
 720            "
 721            SELECT id
 722            FROM users
 723            WHERE email_address = $1
 724            ",
 725        )
 726        .bind(email_address)
 727        .fetch_optional(&mut tx)
 728        .await?;
 729        if existing_user.is_some() {
 730            Err(anyhow!("email address is already in use"))?;
 731        }
 732
 733        let row: Option<(UserId, i32)> = sqlx::query_as(
 734            "
 735            SELECT id, invite_count
 736            FROM users
 737            WHERE invite_code = $1
 738            ",
 739        )
 740        .bind(code)
 741        .fetch_optional(&mut tx)
 742        .await?;
 743
 744        let (inviter_id, invite_count) = match row {
 745            Some(row) => row,
 746            None => Err(Error::Http(
 747                StatusCode::NOT_FOUND,
 748                "invite code not found".to_string(),
 749            ))?,
 750        };
 751
 752        if invite_count == 0 {
 753            Err(Error::Http(
 754                StatusCode::UNAUTHORIZED,
 755                "no invites remaining".to_string(),
 756            ))?;
 757        }
 758
 759        let email_confirmation_code: String = sqlx::query_scalar(
 760            "
 761            INSERT INTO signups
 762            (
 763                email_address,
 764                email_confirmation_code,
 765                email_confirmation_sent,
 766                inviting_user_id,
 767                platform_linux,
 768                platform_mac,
 769                platform_windows,
 770                platform_unknown,
 771                device_id
 772            )
 773            VALUES
 774                ($1, $2, 'f', $3, 'f', 'f', 'f', 't', $4)
 775            ON CONFLICT (email_address)
 776            DO UPDATE SET
 777                inviting_user_id = excluded.inviting_user_id
 778            RETURNING email_confirmation_code
 779            ",
 780        )
 781        .bind(&email_address)
 782        .bind(&random_email_confirmation_code())
 783        .bind(&inviter_id)
 784        .bind(&device_id)
 785        .fetch_one(&mut tx)
 786        .await?;
 787
 788        tx.commit().await?;
 789
 790        Ok(Invite {
 791            email_address: email_address.into(),
 792            email_confirmation_code,
 793        })
 794    }
 795
 796    // projects
 797
 798    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
 799        Ok(sqlx::query_scalar(
 800            "
 801            INSERT INTO projects(host_user_id)
 802            VALUES ($1)
 803            RETURNING id
 804            ",
 805        )
 806        .bind(host_user_id)
 807        .fetch_one(&self.pool)
 808        .await
 809        .map(ProjectId)?)
 810    }
 811
 812    async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
 813        sqlx::query(
 814            "
 815            UPDATE projects
 816            SET unregistered = 't'
 817            WHERE id = $1
 818            ",
 819        )
 820        .bind(project_id)
 821        .execute(&self.pool)
 822        .await?;
 823        Ok(())
 824    }
 825
 826    async fn update_worktree_extensions(
 827        &self,
 828        project_id: ProjectId,
 829        worktree_id: u64,
 830        extensions: HashMap<String, u32>,
 831    ) -> Result<()> {
 832        if extensions.is_empty() {
 833            return Ok(());
 834        }
 835
 836        let mut query = QueryBuilder::new(
 837            "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
 838        );
 839        query.push_values(extensions, |mut query, (extension, count)| {
 840            query
 841                .push_bind(project_id)
 842                .push_bind(worktree_id as i32)
 843                .push_bind(extension)
 844                .push_bind(count as i32);
 845        });
 846        query.push(
 847            "
 848            ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
 849            count = excluded.count
 850            ",
 851        );
 852        query.build().execute(&self.pool).await?;
 853
 854        Ok(())
 855    }
 856
 857    async fn get_project_extensions(
 858        &self,
 859        project_id: ProjectId,
 860    ) -> Result<HashMap<u64, HashMap<String, usize>>> {
 861        #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
 862        struct WorktreeExtension {
 863            worktree_id: i32,
 864            extension: String,
 865            count: i32,
 866        }
 867
 868        let query = "
 869            SELECT worktree_id, extension, count
 870            FROM worktree_extensions
 871            WHERE project_id = $1
 872        ";
 873        let counts = sqlx::query_as::<_, WorktreeExtension>(query)
 874            .bind(&project_id)
 875            .fetch_all(&self.pool)
 876            .await?;
 877
 878        let mut extension_counts = HashMap::default();
 879        for count in counts {
 880            extension_counts
 881                .entry(count.worktree_id as u64)
 882                .or_insert_with(HashMap::default)
 883                .insert(count.extension, count.count as usize);
 884        }
 885        Ok(extension_counts)
 886    }
 887
 888    async fn record_user_activity(
 889        &self,
 890        time_period: Range<OffsetDateTime>,
 891        projects: &[(UserId, ProjectId)],
 892    ) -> Result<()> {
 893        let query = "
 894            INSERT INTO project_activity_periods
 895            (ended_at, duration_millis, user_id, project_id)
 896            VALUES
 897            ($1, $2, $3, $4);
 898        ";
 899
 900        let mut tx = self.pool.begin().await?;
 901        let duration_millis =
 902            ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
 903        for (user_id, project_id) in projects {
 904            sqlx::query(query)
 905                .bind(time_period.end)
 906                .bind(duration_millis)
 907                .bind(user_id)
 908                .bind(project_id)
 909                .execute(&mut tx)
 910                .await?;
 911        }
 912        tx.commit().await?;
 913
 914        Ok(())
 915    }
 916
 917    async fn get_active_user_count(
 918        &self,
 919        time_period: Range<OffsetDateTime>,
 920        min_duration: Duration,
 921        only_collaborative: bool,
 922    ) -> Result<usize> {
 923        let mut with_clause = String::new();
 924        with_clause.push_str("WITH\n");
 925        with_clause.push_str(
 926            "
 927            project_durations AS (
 928                SELECT user_id, project_id, SUM(duration_millis) AS project_duration
 929                FROM project_activity_periods
 930                WHERE $1 < ended_at AND ended_at <= $2
 931                GROUP BY user_id, project_id
 932            ),
 933            ",
 934        );
 935        with_clause.push_str(
 936            "
 937            project_collaborators as (
 938                SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
 939                FROM project_durations
 940                GROUP BY project_id
 941            ),
 942            ",
 943        );
 944
 945        if only_collaborative {
 946            with_clause.push_str(
 947                "
 948                user_durations AS (
 949                    SELECT user_id, SUM(project_duration) as total_duration
 950                    FROM project_durations, project_collaborators
 951                    WHERE
 952                        project_durations.project_id = project_collaborators.project_id AND
 953                        max_collaborators > 1
 954                    GROUP BY user_id
 955                    ORDER BY total_duration DESC
 956                    LIMIT $3
 957                )
 958                ",
 959            );
 960        } else {
 961            with_clause.push_str(
 962                "
 963                user_durations AS (
 964                    SELECT user_id, SUM(project_duration) as total_duration
 965                    FROM project_durations
 966                    GROUP BY user_id
 967                    ORDER BY total_duration DESC
 968                    LIMIT $3
 969                )
 970                ",
 971            );
 972        }
 973
 974        let query = format!(
 975            "
 976            {with_clause}
 977            SELECT count(user_durations.user_id)
 978            FROM user_durations
 979            WHERE user_durations.total_duration >= $3
 980            "
 981        );
 982
 983        let count: i64 = sqlx::query_scalar(&query)
 984            .bind(time_period.start)
 985            .bind(time_period.end)
 986            .bind(min_duration.as_millis() as i64)
 987            .fetch_one(&self.pool)
 988            .await?;
 989        Ok(count as usize)
 990    }
 991
 992    async fn get_top_users_activity_summary(
 993        &self,
 994        time_period: Range<OffsetDateTime>,
 995        max_user_count: usize,
 996    ) -> Result<Vec<UserActivitySummary>> {
 997        let query = "
 998            WITH
 999                project_durations AS (
1000                    SELECT user_id, project_id, SUM(duration_millis) AS project_duration
1001                    FROM project_activity_periods
1002                    WHERE $1 < ended_at AND ended_at <= $2
1003                    GROUP BY user_id, project_id
1004                ),
1005                user_durations AS (
1006                    SELECT user_id, SUM(project_duration) as total_duration
1007                    FROM project_durations
1008                    GROUP BY user_id
1009                    ORDER BY total_duration DESC
1010                    LIMIT $3
1011                ),
1012                project_collaborators as (
1013                    SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
1014                    FROM project_durations
1015                    GROUP BY project_id
1016                )
1017            SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
1018            FROM user_durations, project_durations, project_collaborators, users
1019            WHERE
1020                user_durations.user_id = project_durations.user_id AND
1021                user_durations.user_id = users.id AND
1022                project_durations.project_id = project_collaborators.project_id
1023            ORDER BY total_duration DESC, user_id ASC, project_id ASC
1024        ";
1025
1026        let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
1027            .bind(time_period.start)
1028            .bind(time_period.end)
1029            .bind(max_user_count as i32)
1030            .fetch(&self.pool);
1031
1032        let mut result = Vec::<UserActivitySummary>::new();
1033        while let Some(row) = rows.next().await {
1034            let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?;
1035            let project_id = project_id;
1036            let duration = Duration::from_millis(duration_millis as u64);
1037            let project_activity = ProjectActivitySummary {
1038                id: project_id,
1039                duration,
1040                max_collaborators: project_collaborators as usize,
1041            };
1042            if let Some(last_summary) = result.last_mut() {
1043                if last_summary.id == user_id {
1044                    last_summary.project_activity.push(project_activity);
1045                    continue;
1046                }
1047            }
1048            result.push(UserActivitySummary {
1049                id: user_id,
1050                project_activity: vec![project_activity],
1051                github_login,
1052            });
1053        }
1054
1055        Ok(result)
1056    }
1057
1058    async fn get_user_activity_timeline(
1059        &self,
1060        time_period: Range<OffsetDateTime>,
1061        user_id: UserId,
1062    ) -> Result<Vec<UserActivityPeriod>> {
1063        const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
1064
1065        let query = "
1066            SELECT
1067                project_activity_periods.ended_at,
1068                project_activity_periods.duration_millis,
1069                project_activity_periods.project_id,
1070                worktree_extensions.extension,
1071                worktree_extensions.count
1072            FROM project_activity_periods
1073            LEFT OUTER JOIN
1074                worktree_extensions
1075            ON
1076                project_activity_periods.project_id = worktree_extensions.project_id
1077            WHERE
1078                project_activity_periods.user_id = $1 AND
1079                $2 < project_activity_periods.ended_at AND
1080                project_activity_periods.ended_at <= $3
1081            ORDER BY project_activity_periods.id ASC
1082        ";
1083
1084        let mut rows = sqlx::query_as::<
1085            _,
1086            (
1087                PrimitiveDateTime,
1088                i32,
1089                ProjectId,
1090                Option<String>,
1091                Option<i32>,
1092            ),
1093        >(query)
1094        .bind(user_id)
1095        .bind(time_period.start)
1096        .bind(time_period.end)
1097        .fetch(&self.pool);
1098
1099        let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
1100        while let Some(row) = rows.next().await {
1101            let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
1102            let ended_at = ended_at.assume_utc();
1103            let duration = Duration::from_millis(duration_millis as u64);
1104            let started_at = ended_at - duration;
1105            let project_time_periods = time_periods.entry(project_id).or_default();
1106
1107            if let Some(prev_duration) = project_time_periods.last_mut() {
1108                if started_at <= prev_duration.end + COALESCE_THRESHOLD
1109                    && ended_at >= prev_duration.start
1110                {
1111                    prev_duration.end = cmp::max(prev_duration.end, ended_at);
1112                } else {
1113                    project_time_periods.push(UserActivityPeriod {
1114                        project_id,
1115                        start: started_at,
1116                        end: ended_at,
1117                        extensions: Default::default(),
1118                    });
1119                }
1120            } else {
1121                project_time_periods.push(UserActivityPeriod {
1122                    project_id,
1123                    start: started_at,
1124                    end: ended_at,
1125                    extensions: Default::default(),
1126                });
1127            }
1128
1129            if let Some((extension, extension_count)) = extension.zip(extension_count) {
1130                project_time_periods
1131                    .last_mut()
1132                    .unwrap()
1133                    .extensions
1134                    .insert(extension, extension_count as usize);
1135            }
1136        }
1137
1138        let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
1139        durations.sort_unstable_by_key(|duration| duration.start);
1140        Ok(durations)
1141    }
1142
1143    // contacts
1144
1145    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
1146        let query = "
1147            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
1148            FROM contacts
1149            WHERE user_id_a = $1 OR user_id_b = $1;
1150        ";
1151
1152        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
1153            .bind(user_id)
1154            .fetch(&self.pool);
1155
1156        let mut contacts = Vec::new();
1157        while let Some(row) = rows.next().await {
1158            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
1159
1160            if user_id_a == user_id {
1161                if accepted {
1162                    contacts.push(Contact::Accepted {
1163                        user_id: user_id_b,
1164                        should_notify: should_notify && a_to_b,
1165                    });
1166                } else if a_to_b {
1167                    contacts.push(Contact::Outgoing { user_id: user_id_b })
1168                } else {
1169                    contacts.push(Contact::Incoming {
1170                        user_id: user_id_b,
1171                        should_notify,
1172                    });
1173                }
1174            } else if accepted {
1175                contacts.push(Contact::Accepted {
1176                    user_id: user_id_a,
1177                    should_notify: should_notify && !a_to_b,
1178                });
1179            } else if a_to_b {
1180                contacts.push(Contact::Incoming {
1181                    user_id: user_id_a,
1182                    should_notify,
1183                });
1184            } else {
1185                contacts.push(Contact::Outgoing { user_id: user_id_a });
1186            }
1187        }
1188
1189        contacts.sort_unstable_by_key(|contact| contact.user_id());
1190
1191        Ok(contacts)
1192    }
1193
1194    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1195        let (id_a, id_b) = if user_id_1 < user_id_2 {
1196            (user_id_1, user_id_2)
1197        } else {
1198            (user_id_2, user_id_1)
1199        };
1200
1201        let query = "
1202            SELECT 1 FROM contacts
1203            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
1204            LIMIT 1
1205        ";
1206        Ok(sqlx::query_scalar::<_, i32>(query)
1207            .bind(id_a.0)
1208            .bind(id_b.0)
1209            .fetch_optional(&self.pool)
1210            .await?
1211            .is_some())
1212    }
1213
1214    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
1215        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
1216            (sender_id, receiver_id, true)
1217        } else {
1218            (receiver_id, sender_id, false)
1219        };
1220        let query = "
1221            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1222            VALUES ($1, $2, $3, 'f', 't')
1223            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1224            SET
1225                accepted = 't',
1226                should_notify = 'f'
1227            WHERE
1228                NOT contacts.accepted AND
1229                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1230                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1231        ";
1232        let result = sqlx::query(query)
1233            .bind(id_a.0)
1234            .bind(id_b.0)
1235            .bind(a_to_b)
1236            .execute(&self.pool)
1237            .await?;
1238
1239        if result.rows_affected() == 1 {
1240            Ok(())
1241        } else {
1242            Err(anyhow!("contact already requested"))?
1243        }
1244    }
1245
1246    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1247        let (id_a, id_b) = if responder_id < requester_id {
1248            (responder_id, requester_id)
1249        } else {
1250            (requester_id, responder_id)
1251        };
1252        let query = "
1253            DELETE FROM contacts
1254            WHERE user_id_a = $1 AND user_id_b = $2;
1255        ";
1256        let result = sqlx::query(query)
1257            .bind(id_a.0)
1258            .bind(id_b.0)
1259            .execute(&self.pool)
1260            .await?;
1261
1262        if result.rows_affected() == 1 {
1263            Ok(())
1264        } else {
1265            Err(anyhow!("no such contact"))?
1266        }
1267    }
1268
1269    async fn dismiss_contact_notification(
1270        &self,
1271        user_id: UserId,
1272        contact_user_id: UserId,
1273    ) -> Result<()> {
1274        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1275            (user_id, contact_user_id, true)
1276        } else {
1277            (contact_user_id, user_id, false)
1278        };
1279
1280        let query = "
1281            UPDATE contacts
1282            SET should_notify = 'f'
1283            WHERE
1284                user_id_a = $1 AND user_id_b = $2 AND
1285                (
1286                    (a_to_b = $3 AND accepted) OR
1287                    (a_to_b != $3 AND NOT accepted)
1288                );
1289        ";
1290
1291        let result = sqlx::query(query)
1292            .bind(id_a.0)
1293            .bind(id_b.0)
1294            .bind(a_to_b)
1295            .execute(&self.pool)
1296            .await?;
1297
1298        if result.rows_affected() == 0 {
1299            Err(anyhow!("no such contact request"))?;
1300        }
1301
1302        Ok(())
1303    }
1304
1305    async fn respond_to_contact_request(
1306        &self,
1307        responder_id: UserId,
1308        requester_id: UserId,
1309        accept: bool,
1310    ) -> Result<()> {
1311        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1312            (responder_id, requester_id, false)
1313        } else {
1314            (requester_id, responder_id, true)
1315        };
1316        let result = if accept {
1317            let query = "
1318                UPDATE contacts
1319                SET accepted = 't', should_notify = 't'
1320                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1321            ";
1322            sqlx::query(query)
1323                .bind(id_a.0)
1324                .bind(id_b.0)
1325                .bind(a_to_b)
1326                .execute(&self.pool)
1327                .await?
1328        } else {
1329            let query = "
1330                DELETE FROM contacts
1331                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1332            ";
1333            sqlx::query(query)
1334                .bind(id_a.0)
1335                .bind(id_b.0)
1336                .bind(a_to_b)
1337                .execute(&self.pool)
1338                .await?
1339        };
1340        if result.rows_affected() == 1 {
1341            Ok(())
1342        } else {
1343            Err(anyhow!("no such contact request"))?
1344        }
1345    }
1346
1347    // access tokens
1348
1349    async fn create_access_token_hash(
1350        &self,
1351        user_id: UserId,
1352        access_token_hash: &str,
1353        max_access_token_count: usize,
1354    ) -> Result<()> {
1355        let insert_query = "
1356            INSERT INTO access_tokens (user_id, hash)
1357            VALUES ($1, $2);
1358        ";
1359        let cleanup_query = "
1360            DELETE FROM access_tokens
1361            WHERE id IN (
1362                SELECT id from access_tokens
1363                WHERE user_id = $1
1364                ORDER BY id DESC
1365                OFFSET $3
1366            )
1367        ";
1368
1369        let mut tx = self.pool.begin().await?;
1370        sqlx::query(insert_query)
1371            .bind(user_id.0)
1372            .bind(access_token_hash)
1373            .execute(&mut tx)
1374            .await?;
1375        sqlx::query(cleanup_query)
1376            .bind(user_id.0)
1377            .bind(access_token_hash)
1378            .bind(max_access_token_count as i32)
1379            .execute(&mut tx)
1380            .await?;
1381        Ok(tx.commit().await?)
1382    }
1383
1384    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1385        let query = "
1386            SELECT hash
1387            FROM access_tokens
1388            WHERE user_id = $1
1389            ORDER BY id DESC
1390        ";
1391        Ok(sqlx::query_scalar(query)
1392            .bind(user_id.0)
1393            .fetch_all(&self.pool)
1394            .await?)
1395    }
1396
1397    // orgs
1398
1399    #[allow(unused)] // Help rust-analyzer
1400    #[cfg(any(test, feature = "seed-support"))]
1401    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1402        let query = "
1403            SELECT *
1404            FROM orgs
1405            WHERE slug = $1
1406        ";
1407        Ok(sqlx::query_as(query)
1408            .bind(slug)
1409            .fetch_optional(&self.pool)
1410            .await?)
1411    }
1412
1413    #[cfg(any(test, feature = "seed-support"))]
1414    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1415        let query = "
1416            INSERT INTO orgs (name, slug)
1417            VALUES ($1, $2)
1418            RETURNING id
1419        ";
1420        Ok(sqlx::query_scalar(query)
1421            .bind(name)
1422            .bind(slug)
1423            .fetch_one(&self.pool)
1424            .await
1425            .map(OrgId)?)
1426    }
1427
1428    #[cfg(any(test, feature = "seed-support"))]
1429    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1430        let query = "
1431            INSERT INTO org_memberships (org_id, user_id, admin)
1432            VALUES ($1, $2, $3)
1433            ON CONFLICT DO NOTHING
1434        ";
1435        Ok(sqlx::query(query)
1436            .bind(org_id.0)
1437            .bind(user_id.0)
1438            .bind(is_admin)
1439            .execute(&self.pool)
1440            .await
1441            .map(drop)?)
1442    }
1443
1444    // channels
1445
1446    #[cfg(any(test, feature = "seed-support"))]
1447    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1448        let query = "
1449            INSERT INTO channels (owner_id, owner_is_user, name)
1450            VALUES ($1, false, $2)
1451            RETURNING id
1452        ";
1453        Ok(sqlx::query_scalar(query)
1454            .bind(org_id.0)
1455            .bind(name)
1456            .fetch_one(&self.pool)
1457            .await
1458            .map(ChannelId)?)
1459    }
1460
1461    #[allow(unused)] // Help rust-analyzer
1462    #[cfg(any(test, feature = "seed-support"))]
1463    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1464        let query = "
1465            SELECT *
1466            FROM channels
1467            WHERE
1468                channels.owner_is_user = false AND
1469                channels.owner_id = $1
1470        ";
1471        Ok(sqlx::query_as(query)
1472            .bind(org_id.0)
1473            .fetch_all(&self.pool)
1474            .await?)
1475    }
1476
1477    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1478        let query = "
1479            SELECT
1480                channels.*
1481            FROM
1482                channel_memberships, channels
1483            WHERE
1484                channel_memberships.user_id = $1 AND
1485                channel_memberships.channel_id = channels.id
1486        ";
1487        Ok(sqlx::query_as(query)
1488            .bind(user_id.0)
1489            .fetch_all(&self.pool)
1490            .await?)
1491    }
1492
1493    async fn can_user_access_channel(
1494        &self,
1495        user_id: UserId,
1496        channel_id: ChannelId,
1497    ) -> Result<bool> {
1498        let query = "
1499            SELECT id
1500            FROM channel_memberships
1501            WHERE user_id = $1 AND channel_id = $2
1502            LIMIT 1
1503        ";
1504        Ok(sqlx::query_scalar::<_, i32>(query)
1505            .bind(user_id.0)
1506            .bind(channel_id.0)
1507            .fetch_optional(&self.pool)
1508            .await
1509            .map(|e| e.is_some())?)
1510    }
1511
1512    #[cfg(any(test, feature = "seed-support"))]
1513    async fn add_channel_member(
1514        &self,
1515        channel_id: ChannelId,
1516        user_id: UserId,
1517        is_admin: bool,
1518    ) -> Result<()> {
1519        let query = "
1520            INSERT INTO channel_memberships (channel_id, user_id, admin)
1521            VALUES ($1, $2, $3)
1522            ON CONFLICT DO NOTHING
1523        ";
1524        Ok(sqlx::query(query)
1525            .bind(channel_id.0)
1526            .bind(user_id.0)
1527            .bind(is_admin)
1528            .execute(&self.pool)
1529            .await
1530            .map(drop)?)
1531    }
1532
1533    // messages
1534
1535    async fn create_channel_message(
1536        &self,
1537        channel_id: ChannelId,
1538        sender_id: UserId,
1539        body: &str,
1540        timestamp: OffsetDateTime,
1541        nonce: u128,
1542    ) -> Result<MessageId> {
1543        let query = "
1544            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1545            VALUES ($1, $2, $3, $4, $5)
1546            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1547            RETURNING id
1548        ";
1549        Ok(sqlx::query_scalar(query)
1550            .bind(channel_id.0)
1551            .bind(sender_id.0)
1552            .bind(body)
1553            .bind(timestamp)
1554            .bind(Uuid::from_u128(nonce))
1555            .fetch_one(&self.pool)
1556            .await
1557            .map(MessageId)?)
1558    }
1559
1560    async fn get_channel_messages(
1561        &self,
1562        channel_id: ChannelId,
1563        count: usize,
1564        before_id: Option<MessageId>,
1565    ) -> Result<Vec<ChannelMessage>> {
1566        let query = r#"
1567            SELECT * FROM (
1568                SELECT
1569                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1570                FROM
1571                    channel_messages
1572                WHERE
1573                    channel_id = $1 AND
1574                    id < $2
1575                ORDER BY id DESC
1576                LIMIT $3
1577            ) as recent_messages
1578            ORDER BY id ASC
1579        "#;
1580        Ok(sqlx::query_as(query)
1581            .bind(channel_id.0)
1582            .bind(before_id.unwrap_or(MessageId::MAX))
1583            .bind(count as i64)
1584            .fetch_all(&self.pool)
1585            .await?)
1586    }
1587
1588    #[cfg(test)]
1589    async fn teardown(&self, url: &str) {
1590        use util::ResultExt;
1591
1592        let query = "
1593            SELECT pg_terminate_backend(pg_stat_activity.pid)
1594            FROM pg_stat_activity
1595            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1596        ";
1597        sqlx::query(query).execute(&self.pool).await.log_err();
1598        self.pool.close().await;
1599        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1600            .await
1601            .log_err();
1602    }
1603
1604    #[cfg(test)]
1605    fn as_fake(&self) -> Option<&FakeDb> {
1606        None
1607    }
1608}
1609
1610macro_rules! id_type {
1611    ($name:ident) => {
1612        #[derive(
1613            Clone,
1614            Copy,
1615            Debug,
1616            Default,
1617            PartialEq,
1618            Eq,
1619            PartialOrd,
1620            Ord,
1621            Hash,
1622            sqlx::Type,
1623            Serialize,
1624            Deserialize,
1625        )]
1626        #[sqlx(transparent)]
1627        #[serde(transparent)]
1628        pub struct $name(pub i32);
1629
1630        impl $name {
1631            #[allow(unused)]
1632            pub const MAX: Self = Self(i32::MAX);
1633
1634            #[allow(unused)]
1635            pub fn from_proto(value: u64) -> Self {
1636                Self(value as i32)
1637            }
1638
1639            #[allow(unused)]
1640            pub fn to_proto(self) -> u64 {
1641                self.0 as u64
1642            }
1643        }
1644
1645        impl std::fmt::Display for $name {
1646            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
1647                self.0.fmt(f)
1648            }
1649        }
1650    };
1651}
1652
1653id_type!(UserId);
1654#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
1655pub struct User {
1656    pub id: UserId,
1657    pub github_login: String,
1658    pub github_user_id: Option<i32>,
1659    pub email_address: Option<String>,
1660    pub admin: bool,
1661    pub invite_code: Option<String>,
1662    pub invite_count: i32,
1663    pub connected_once: bool,
1664}
1665
1666id_type!(ProjectId);
1667#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
1668pub struct Project {
1669    pub id: ProjectId,
1670    pub host_user_id: UserId,
1671    pub unregistered: bool,
1672}
1673
1674#[derive(Clone, Debug, PartialEq, Serialize)]
1675pub struct UserActivitySummary {
1676    pub id: UserId,
1677    pub github_login: String,
1678    pub project_activity: Vec<ProjectActivitySummary>,
1679}
1680
1681#[derive(Clone, Debug, PartialEq, Serialize)]
1682pub struct ProjectActivitySummary {
1683    pub id: ProjectId,
1684    pub duration: Duration,
1685    pub max_collaborators: usize,
1686}
1687
1688#[derive(Clone, Debug, PartialEq, Serialize)]
1689pub struct UserActivityPeriod {
1690    pub project_id: ProjectId,
1691    #[serde(with = "time::serde::iso8601")]
1692    pub start: OffsetDateTime,
1693    #[serde(with = "time::serde::iso8601")]
1694    pub end: OffsetDateTime,
1695    pub extensions: HashMap<String, usize>,
1696}
1697
1698id_type!(OrgId);
1699#[derive(FromRow)]
1700pub struct Org {
1701    pub id: OrgId,
1702    pub name: String,
1703    pub slug: String,
1704}
1705
1706id_type!(ChannelId);
1707#[derive(Clone, Debug, FromRow, Serialize)]
1708pub struct Channel {
1709    pub id: ChannelId,
1710    pub name: String,
1711    pub owner_id: i32,
1712    pub owner_is_user: bool,
1713}
1714
1715id_type!(MessageId);
1716#[derive(Clone, Debug, FromRow)]
1717pub struct ChannelMessage {
1718    pub id: MessageId,
1719    pub channel_id: ChannelId,
1720    pub sender_id: UserId,
1721    pub body: String,
1722    pub sent_at: OffsetDateTime,
1723    pub nonce: Uuid,
1724}
1725
1726#[derive(Clone, Debug, PartialEq, Eq)]
1727pub enum Contact {
1728    Accepted {
1729        user_id: UserId,
1730        should_notify: bool,
1731    },
1732    Outgoing {
1733        user_id: UserId,
1734    },
1735    Incoming {
1736        user_id: UserId,
1737        should_notify: bool,
1738    },
1739}
1740
1741impl Contact {
1742    pub fn user_id(&self) -> UserId {
1743        match self {
1744            Contact::Accepted { user_id, .. } => *user_id,
1745            Contact::Outgoing { user_id } => *user_id,
1746            Contact::Incoming { user_id, .. } => *user_id,
1747        }
1748    }
1749}
1750
1751#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
1752pub struct IncomingContactRequest {
1753    pub requester_id: UserId,
1754    pub should_notify: bool,
1755}
1756
1757#[derive(Clone, Deserialize)]
1758pub struct Signup {
1759    pub email_address: String,
1760    pub platform_mac: bool,
1761    pub platform_windows: bool,
1762    pub platform_linux: bool,
1763    pub editor_features: Vec<String>,
1764    pub programming_languages: Vec<String>,
1765    pub device_id: Option<String>,
1766}
1767
1768#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
1769pub struct WaitlistSummary {
1770    #[sqlx(default)]
1771    pub count: i64,
1772    #[sqlx(default)]
1773    pub linux_count: i64,
1774    #[sqlx(default)]
1775    pub mac_count: i64,
1776    #[sqlx(default)]
1777    pub windows_count: i64,
1778    #[sqlx(default)]
1779    pub unknown_count: i64,
1780}
1781
1782#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
1783pub struct Invite {
1784    pub email_address: String,
1785    pub email_confirmation_code: String,
1786}
1787
1788#[derive(Debug, Serialize, Deserialize)]
1789pub struct NewUserParams {
1790    pub github_login: String,
1791    pub github_user_id: i32,
1792    pub invite_count: i32,
1793}
1794
1795#[derive(Debug)]
1796pub struct NewUserResult {
1797    pub user_id: UserId,
1798    pub metrics_id: String,
1799    pub inviting_user_id: Option<UserId>,
1800    pub signup_device_id: Option<String>,
1801}
1802
1803fn random_invite_code() -> String {
1804    nanoid::nanoid!(16)
1805}
1806
1807fn random_email_confirmation_code() -> String {
1808    nanoid::nanoid!(64)
1809}
1810
1811#[cfg(test)]
1812pub use test::*;
1813
1814#[cfg(test)]
1815mod test {
1816    use super::*;
1817    use anyhow::anyhow;
1818    use collections::BTreeMap;
1819    use gpui::executor::Background;
1820    use lazy_static::lazy_static;
1821    use parking_lot::Mutex;
1822    use rand::prelude::*;
1823    use sqlx::{migrate::MigrateDatabase, Postgres};
1824    use std::sync::Arc;
1825    use util::post_inc;
1826
1827    pub struct FakeDb {
1828        background: Arc<Background>,
1829        pub users: Mutex<BTreeMap<UserId, User>>,
1830        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
1831        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
1832        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
1833        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
1834        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
1835        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
1836        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
1837        pub contacts: Mutex<Vec<FakeContact>>,
1838        next_channel_message_id: Mutex<i32>,
1839        next_user_id: Mutex<i32>,
1840        next_org_id: Mutex<i32>,
1841        next_channel_id: Mutex<i32>,
1842        next_project_id: Mutex<i32>,
1843    }
1844
1845    #[derive(Debug)]
1846    pub struct FakeContact {
1847        pub requester_id: UserId,
1848        pub responder_id: UserId,
1849        pub accepted: bool,
1850        pub should_notify: bool,
1851    }
1852
1853    impl FakeDb {
1854        pub fn new(background: Arc<Background>) -> Self {
1855            Self {
1856                background,
1857                users: Default::default(),
1858                next_user_id: Mutex::new(0),
1859                projects: Default::default(),
1860                worktree_extensions: Default::default(),
1861                next_project_id: Mutex::new(1),
1862                orgs: Default::default(),
1863                next_org_id: Mutex::new(1),
1864                org_memberships: Default::default(),
1865                channels: Default::default(),
1866                next_channel_id: Mutex::new(1),
1867                channel_memberships: Default::default(),
1868                channel_messages: Default::default(),
1869                next_channel_message_id: Mutex::new(1),
1870                contacts: Default::default(),
1871            }
1872        }
1873    }
1874
1875    #[async_trait]
1876    impl Db for FakeDb {
1877        async fn create_user(
1878            &self,
1879            email_address: &str,
1880            admin: bool,
1881            params: NewUserParams,
1882        ) -> Result<NewUserResult> {
1883            self.background.simulate_random_delay().await;
1884
1885            let mut users = self.users.lock();
1886            let user_id = if let Some(user) = users
1887                .values()
1888                .find(|user| user.github_login == params.github_login)
1889            {
1890                user.id
1891            } else {
1892                let id = post_inc(&mut *self.next_user_id.lock());
1893                let user_id = UserId(id);
1894                users.insert(
1895                    user_id,
1896                    User {
1897                        id: user_id,
1898                        github_login: params.github_login,
1899                        github_user_id: Some(params.github_user_id),
1900                        email_address: Some(email_address.to_string()),
1901                        admin,
1902                        invite_code: None,
1903                        invite_count: 0,
1904                        connected_once: false,
1905                    },
1906                );
1907                user_id
1908            };
1909            Ok(NewUserResult {
1910                user_id,
1911                metrics_id: "the-metrics-id".to_string(),
1912                inviting_user_id: None,
1913                signup_device_id: None,
1914            })
1915        }
1916
1917        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
1918            unimplemented!()
1919        }
1920
1921        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1922            unimplemented!()
1923        }
1924
1925        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1926            self.background.simulate_random_delay().await;
1927            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1928        }
1929
1930        async fn get_user_metrics_id(&self, _id: UserId) -> Result<String> {
1931            Ok("the-metrics-id".to_string())
1932        }
1933
1934        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1935            self.background.simulate_random_delay().await;
1936            let users = self.users.lock();
1937            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1938        }
1939
1940        async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
1941            unimplemented!()
1942        }
1943
1944        async fn get_user_by_github_account(
1945            &self,
1946            github_login: &str,
1947            github_user_id: Option<i32>,
1948        ) -> Result<Option<User>> {
1949            self.background.simulate_random_delay().await;
1950            if let Some(github_user_id) = github_user_id {
1951                for user in self.users.lock().values_mut() {
1952                    if user.github_user_id == Some(github_user_id) {
1953                        user.github_login = github_login.into();
1954                        return Ok(Some(user.clone()));
1955                    }
1956                    if user.github_login == github_login {
1957                        user.github_user_id = Some(github_user_id);
1958                        return Ok(Some(user.clone()));
1959                    }
1960                }
1961                Ok(None)
1962            } else {
1963                Ok(self
1964                    .users
1965                    .lock()
1966                    .values()
1967                    .find(|user| user.github_login == github_login)
1968                    .cloned())
1969            }
1970        }
1971
1972        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1973            unimplemented!()
1974        }
1975
1976        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
1977            self.background.simulate_random_delay().await;
1978            let mut users = self.users.lock();
1979            let mut user = users
1980                .get_mut(&id)
1981                .ok_or_else(|| anyhow!("user not found"))?;
1982            user.connected_once = connected_once;
1983            Ok(())
1984        }
1985
1986        async fn destroy_user(&self, _id: UserId) -> Result<()> {
1987            unimplemented!()
1988        }
1989
1990        // signups
1991
1992        async fn create_signup(&self, _signup: Signup) -> Result<()> {
1993            unimplemented!()
1994        }
1995
1996        async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
1997            unimplemented!()
1998        }
1999
2000        async fn get_unsent_invites(&self, _count: usize) -> Result<Vec<Invite>> {
2001            unimplemented!()
2002        }
2003
2004        async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
2005            unimplemented!()
2006        }
2007
2008        async fn create_user_from_invite(
2009            &self,
2010            _invite: &Invite,
2011            _user: NewUserParams,
2012        ) -> Result<Option<NewUserResult>> {
2013            unimplemented!()
2014        }
2015
2016        // invite codes
2017
2018        async fn set_invite_count_for_user(&self, _id: UserId, _count: u32) -> Result<()> {
2019            unimplemented!()
2020        }
2021
2022        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
2023            self.background.simulate_random_delay().await;
2024            Ok(None)
2025        }
2026
2027        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
2028            unimplemented!()
2029        }
2030
2031        async fn create_invite_from_code(
2032            &self,
2033            _code: &str,
2034            _email_address: &str,
2035            _device_id: Option<&str>,
2036        ) -> Result<Invite> {
2037            unimplemented!()
2038        }
2039
2040        // projects
2041
2042        async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2043            self.background.simulate_random_delay().await;
2044            if !self.users.lock().contains_key(&host_user_id) {
2045                Err(anyhow!("no such user"))?;
2046            }
2047
2048            let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2049            self.projects.lock().insert(
2050                project_id,
2051                Project {
2052                    id: project_id,
2053                    host_user_id,
2054                    unregistered: false,
2055                },
2056            );
2057            Ok(project_id)
2058        }
2059
2060        async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2061            self.background.simulate_random_delay().await;
2062            self.projects
2063                .lock()
2064                .get_mut(&project_id)
2065                .ok_or_else(|| anyhow!("no such project"))?
2066                .unregistered = true;
2067            Ok(())
2068        }
2069
2070        async fn update_worktree_extensions(
2071            &self,
2072            project_id: ProjectId,
2073            worktree_id: u64,
2074            extensions: HashMap<String, u32>,
2075        ) -> Result<()> {
2076            self.background.simulate_random_delay().await;
2077            if !self.projects.lock().contains_key(&project_id) {
2078                Err(anyhow!("no such project"))?;
2079            }
2080
2081            for (extension, count) in extensions {
2082                self.worktree_extensions
2083                    .lock()
2084                    .insert((project_id, worktree_id, extension), count);
2085            }
2086
2087            Ok(())
2088        }
2089
2090        async fn get_project_extensions(
2091            &self,
2092            _project_id: ProjectId,
2093        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
2094            unimplemented!()
2095        }
2096
2097        async fn record_user_activity(
2098            &self,
2099            _time_period: Range<OffsetDateTime>,
2100            _active_projects: &[(UserId, ProjectId)],
2101        ) -> Result<()> {
2102            unimplemented!()
2103        }
2104
2105        async fn get_active_user_count(
2106            &self,
2107            _time_period: Range<OffsetDateTime>,
2108            _min_duration: Duration,
2109            _only_collaborative: bool,
2110        ) -> Result<usize> {
2111            unimplemented!()
2112        }
2113
2114        async fn get_top_users_activity_summary(
2115            &self,
2116            _time_period: Range<OffsetDateTime>,
2117            _limit: usize,
2118        ) -> Result<Vec<UserActivitySummary>> {
2119            unimplemented!()
2120        }
2121
2122        async fn get_user_activity_timeline(
2123            &self,
2124            _time_period: Range<OffsetDateTime>,
2125            _user_id: UserId,
2126        ) -> Result<Vec<UserActivityPeriod>> {
2127            unimplemented!()
2128        }
2129
2130        // contacts
2131
2132        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2133            self.background.simulate_random_delay().await;
2134            let mut contacts = Vec::new();
2135
2136            for contact in self.contacts.lock().iter() {
2137                if contact.requester_id == id {
2138                    if contact.accepted {
2139                        contacts.push(Contact::Accepted {
2140                            user_id: contact.responder_id,
2141                            should_notify: contact.should_notify,
2142                        });
2143                    } else {
2144                        contacts.push(Contact::Outgoing {
2145                            user_id: contact.responder_id,
2146                        });
2147                    }
2148                } else if contact.responder_id == id {
2149                    if contact.accepted {
2150                        contacts.push(Contact::Accepted {
2151                            user_id: contact.requester_id,
2152                            should_notify: false,
2153                        });
2154                    } else {
2155                        contacts.push(Contact::Incoming {
2156                            user_id: contact.requester_id,
2157                            should_notify: contact.should_notify,
2158                        });
2159                    }
2160                }
2161            }
2162
2163            contacts.sort_unstable_by_key(|contact| contact.user_id());
2164            Ok(contacts)
2165        }
2166
2167        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2168            self.background.simulate_random_delay().await;
2169            Ok(self.contacts.lock().iter().any(|contact| {
2170                contact.accepted
2171                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2172                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2173            }))
2174        }
2175
2176        async fn send_contact_request(
2177            &self,
2178            requester_id: UserId,
2179            responder_id: UserId,
2180        ) -> Result<()> {
2181            self.background.simulate_random_delay().await;
2182            let mut contacts = self.contacts.lock();
2183            for contact in contacts.iter_mut() {
2184                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2185                    if contact.accepted {
2186                        Err(anyhow!("contact already exists"))?;
2187                    } else {
2188                        Err(anyhow!("contact already requested"))?;
2189                    }
2190                }
2191                if contact.responder_id == requester_id && contact.requester_id == responder_id {
2192                    if contact.accepted {
2193                        Err(anyhow!("contact already exists"))?;
2194                    } else {
2195                        contact.accepted = true;
2196                        contact.should_notify = false;
2197                        return Ok(());
2198                    }
2199                }
2200            }
2201            contacts.push(FakeContact {
2202                requester_id,
2203                responder_id,
2204                accepted: false,
2205                should_notify: true,
2206            });
2207            Ok(())
2208        }
2209
2210        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2211            self.background.simulate_random_delay().await;
2212            self.contacts.lock().retain(|contact| {
2213                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2214            });
2215            Ok(())
2216        }
2217
2218        async fn dismiss_contact_notification(
2219            &self,
2220            user_id: UserId,
2221            contact_user_id: UserId,
2222        ) -> Result<()> {
2223            self.background.simulate_random_delay().await;
2224            let mut contacts = self.contacts.lock();
2225            for contact in contacts.iter_mut() {
2226                if contact.requester_id == contact_user_id
2227                    && contact.responder_id == user_id
2228                    && !contact.accepted
2229                {
2230                    contact.should_notify = false;
2231                    return Ok(());
2232                }
2233                if contact.requester_id == user_id
2234                    && contact.responder_id == contact_user_id
2235                    && contact.accepted
2236                {
2237                    contact.should_notify = false;
2238                    return Ok(());
2239                }
2240            }
2241            Err(anyhow!("no such notification"))?
2242        }
2243
2244        async fn respond_to_contact_request(
2245            &self,
2246            responder_id: UserId,
2247            requester_id: UserId,
2248            accept: bool,
2249        ) -> Result<()> {
2250            self.background.simulate_random_delay().await;
2251            let mut contacts = self.contacts.lock();
2252            for (ix, contact) in contacts.iter_mut().enumerate() {
2253                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2254                    if contact.accepted {
2255                        Err(anyhow!("contact already confirmed"))?;
2256                    }
2257                    if accept {
2258                        contact.accepted = true;
2259                        contact.should_notify = true;
2260                    } else {
2261                        contacts.remove(ix);
2262                    }
2263                    return Ok(());
2264                }
2265            }
2266            Err(anyhow!("no such contact request"))?
2267        }
2268
2269        async fn create_access_token_hash(
2270            &self,
2271            _user_id: UserId,
2272            _access_token_hash: &str,
2273            _max_access_token_count: usize,
2274        ) -> Result<()> {
2275            unimplemented!()
2276        }
2277
2278        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
2279            unimplemented!()
2280        }
2281
2282        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
2283            unimplemented!()
2284        }
2285
2286        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2287            self.background.simulate_random_delay().await;
2288            let mut orgs = self.orgs.lock();
2289            if orgs.values().any(|org| org.slug == slug) {
2290                Err(anyhow!("org already exists"))?
2291            } else {
2292                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2293                orgs.insert(
2294                    org_id,
2295                    Org {
2296                        id: org_id,
2297                        name: name.to_string(),
2298                        slug: slug.to_string(),
2299                    },
2300                );
2301                Ok(org_id)
2302            }
2303        }
2304
2305        async fn add_org_member(
2306            &self,
2307            org_id: OrgId,
2308            user_id: UserId,
2309            is_admin: bool,
2310        ) -> Result<()> {
2311            self.background.simulate_random_delay().await;
2312            if !self.orgs.lock().contains_key(&org_id) {
2313                Err(anyhow!("org does not exist"))?;
2314            }
2315            if !self.users.lock().contains_key(&user_id) {
2316                Err(anyhow!("user does not exist"))?;
2317            }
2318
2319            self.org_memberships
2320                .lock()
2321                .entry((org_id, user_id))
2322                .or_insert(is_admin);
2323            Ok(())
2324        }
2325
2326        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2327            self.background.simulate_random_delay().await;
2328            if !self.orgs.lock().contains_key(&org_id) {
2329                Err(anyhow!("org does not exist"))?;
2330            }
2331
2332            let mut channels = self.channels.lock();
2333            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2334            channels.insert(
2335                channel_id,
2336                Channel {
2337                    id: channel_id,
2338                    name: name.to_string(),
2339                    owner_id: org_id.0,
2340                    owner_is_user: false,
2341                },
2342            );
2343            Ok(channel_id)
2344        }
2345
2346        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2347            self.background.simulate_random_delay().await;
2348            Ok(self
2349                .channels
2350                .lock()
2351                .values()
2352                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2353                .cloned()
2354                .collect())
2355        }
2356
2357        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2358            self.background.simulate_random_delay().await;
2359            let channels = self.channels.lock();
2360            let memberships = self.channel_memberships.lock();
2361            Ok(channels
2362                .values()
2363                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2364                .cloned()
2365                .collect())
2366        }
2367
2368        async fn can_user_access_channel(
2369            &self,
2370            user_id: UserId,
2371            channel_id: ChannelId,
2372        ) -> Result<bool> {
2373            self.background.simulate_random_delay().await;
2374            Ok(self
2375                .channel_memberships
2376                .lock()
2377                .contains_key(&(channel_id, user_id)))
2378        }
2379
2380        async fn add_channel_member(
2381            &self,
2382            channel_id: ChannelId,
2383            user_id: UserId,
2384            is_admin: bool,
2385        ) -> Result<()> {
2386            self.background.simulate_random_delay().await;
2387            if !self.channels.lock().contains_key(&channel_id) {
2388                Err(anyhow!("channel does not exist"))?;
2389            }
2390            if !self.users.lock().contains_key(&user_id) {
2391                Err(anyhow!("user does not exist"))?;
2392            }
2393
2394            self.channel_memberships
2395                .lock()
2396                .entry((channel_id, user_id))
2397                .or_insert(is_admin);
2398            Ok(())
2399        }
2400
2401        async fn create_channel_message(
2402            &self,
2403            channel_id: ChannelId,
2404            sender_id: UserId,
2405            body: &str,
2406            timestamp: OffsetDateTime,
2407            nonce: u128,
2408        ) -> Result<MessageId> {
2409            self.background.simulate_random_delay().await;
2410            if !self.channels.lock().contains_key(&channel_id) {
2411                Err(anyhow!("channel does not exist"))?;
2412            }
2413            if !self.users.lock().contains_key(&sender_id) {
2414                Err(anyhow!("user does not exist"))?;
2415            }
2416
2417            let mut messages = self.channel_messages.lock();
2418            if let Some(message) = messages
2419                .values()
2420                .find(|message| message.nonce.as_u128() == nonce)
2421            {
2422                Ok(message.id)
2423            } else {
2424                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2425                messages.insert(
2426                    message_id,
2427                    ChannelMessage {
2428                        id: message_id,
2429                        channel_id,
2430                        sender_id,
2431                        body: body.to_string(),
2432                        sent_at: timestamp,
2433                        nonce: Uuid::from_u128(nonce),
2434                    },
2435                );
2436                Ok(message_id)
2437            }
2438        }
2439
2440        async fn get_channel_messages(
2441            &self,
2442            channel_id: ChannelId,
2443            count: usize,
2444            before_id: Option<MessageId>,
2445        ) -> Result<Vec<ChannelMessage>> {
2446            self.background.simulate_random_delay().await;
2447            let mut messages = self
2448                .channel_messages
2449                .lock()
2450                .values()
2451                .rev()
2452                .filter(|message| {
2453                    message.channel_id == channel_id
2454                        && message.id < before_id.unwrap_or(MessageId::MAX)
2455                })
2456                .take(count)
2457                .cloned()
2458                .collect::<Vec<_>>();
2459            messages.sort_unstable_by_key(|message| message.id);
2460            Ok(messages)
2461        }
2462
2463        async fn teardown(&self, _: &str) {}
2464
2465        #[cfg(test)]
2466        fn as_fake(&self) -> Option<&FakeDb> {
2467            Some(self)
2468        }
2469    }
2470
2471    pub struct TestDb {
2472        pub db: Option<Arc<dyn Db>>,
2473        pub url: String,
2474    }
2475
2476    impl TestDb {
2477        #[allow(clippy::await_holding_lock)]
2478        pub async fn postgres() -> Self {
2479            lazy_static! {
2480                static ref LOCK: Mutex<()> = Mutex::new(());
2481            }
2482
2483            let _guard = LOCK.lock();
2484            let mut rng = StdRng::from_entropy();
2485            let name = format!("zed-test-{}", rng.gen::<u128>());
2486            let url = format!("postgres://postgres@localhost/{}", name);
2487            Postgres::create_database(&url)
2488                .await
2489                .expect("failed to create test db");
2490            let db = PostgresDb::new(&url, 5).await.unwrap();
2491            db.migrate(Path::new(DEFAULT_MIGRATIONS_PATH.unwrap()), false)
2492                .await
2493                .unwrap();
2494            Self {
2495                db: Some(Arc::new(db)),
2496                url,
2497            }
2498        }
2499
2500        pub fn fake(background: Arc<Background>) -> Self {
2501            Self {
2502                db: Some(Arc::new(FakeDb::new(background))),
2503                url: Default::default(),
2504            }
2505        }
2506
2507        pub fn db(&self) -> &Arc<dyn Db> {
2508            self.db.as_ref().unwrap()
2509        }
2510    }
2511
2512    impl Drop for TestDb {
2513        fn drop(&mut self) {
2514            if let Some(db) = self.db.take() {
2515                futures::executor::block_on(db.teardown(&self.url));
2516            }
2517        }
2518    }
2519}