db.rs

   1use crate::{Error, Result};
   2use anyhow::{anyhow, Context};
   3use async_trait::async_trait;
   4use axum::http::StatusCode;
   5use collections::HashMap;
   6use futures::StreamExt;
   7use serde::{Deserialize, Serialize};
   8pub use sqlx::postgres::PgPoolOptions as DbOptions;
   9use sqlx::{
  10    migrate::{Migrate as _, Migration, MigrationSource},
  11    types::Uuid,
  12    FromRow, QueryBuilder,
  13};
  14use std::{cmp, ops::Range, path::Path, time::Duration};
  15use time::{OffsetDateTime, PrimitiveDateTime};
  16
  17#[async_trait]
  18pub trait Db: Send + Sync {
  19    async fn create_user(
  20        &self,
  21        email_address: &str,
  22        admin: bool,
  23        params: NewUserParams,
  24    ) -> Result<NewUserResult>;
  25    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
  26    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
  27    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
  28    async fn get_user_metrics_id(&self, id: UserId) -> Result<String>;
  29    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
  30    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
  31    async fn get_user_by_github_account(
  32        &self,
  33        github_login: &str,
  34        github_user_id: Option<i32>,
  35    ) -> Result<Option<User>>;
  36    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
  37    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
  38    async fn destroy_user(&self, id: UserId) -> Result<()>;
  39
  40    async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()>;
  41    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
  42    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
  43    async fn create_invite_from_code(
  44        &self,
  45        code: &str,
  46        email_address: &str,
  47        device_id: Option<&str>,
  48    ) -> Result<Invite>;
  49
  50    async fn create_signup(&self, signup: Signup) -> Result<()>;
  51    async fn get_waitlist_summary(&self) -> Result<WaitlistSummary>;
  52    async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>>;
  53    async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()>;
  54    async fn create_user_from_invite(
  55        &self,
  56        invite: &Invite,
  57        user: NewUserParams,
  58    ) -> Result<Option<NewUserResult>>;
  59
  60    /// Registers a new project for the given user.
  61    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
  62
  63    /// Unregisters a project for the given project id.
  64    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
  65
  66    /// Update file counts by extension for the given project and worktree.
  67    async fn update_worktree_extensions(
  68        &self,
  69        project_id: ProjectId,
  70        worktree_id: u64,
  71        extensions: HashMap<String, u32>,
  72    ) -> Result<()>;
  73
  74    /// Get the file counts on the given project keyed by their worktree and extension.
  75    async fn get_project_extensions(
  76        &self,
  77        project_id: ProjectId,
  78    ) -> Result<HashMap<u64, HashMap<String, usize>>>;
  79
  80    /// Record which users have been active in which projects during
  81    /// a given period of time.
  82    async fn record_user_activity(
  83        &self,
  84        time_period: Range<OffsetDateTime>,
  85        active_projects: &[(UserId, ProjectId)],
  86    ) -> Result<()>;
  87
  88    /// Get the number of users who have been active in the given
  89    /// time period for at least the given time duration.
  90    async fn get_active_user_count(
  91        &self,
  92        time_period: Range<OffsetDateTime>,
  93        min_duration: Duration,
  94        only_collaborative: bool,
  95    ) -> Result<usize>;
  96
  97    /// Get the users that have been most active during the given time period,
  98    /// along with the amount of time they have been active in each project.
  99    async fn get_top_users_activity_summary(
 100        &self,
 101        time_period: Range<OffsetDateTime>,
 102        max_user_count: usize,
 103    ) -> Result<Vec<UserActivitySummary>>;
 104
 105    /// Get the project activity for the given user and time period.
 106    async fn get_user_activity_timeline(
 107        &self,
 108        time_period: Range<OffsetDateTime>,
 109        user_id: UserId,
 110    ) -> Result<Vec<UserActivityPeriod>>;
 111
 112    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
 113    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
 114    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
 115    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
 116    async fn dismiss_contact_notification(
 117        &self,
 118        responder_id: UserId,
 119        requester_id: UserId,
 120    ) -> Result<()>;
 121    async fn respond_to_contact_request(
 122        &self,
 123        responder_id: UserId,
 124        requester_id: UserId,
 125        accept: bool,
 126    ) -> Result<()>;
 127
 128    async fn create_access_token_hash(
 129        &self,
 130        user_id: UserId,
 131        access_token_hash: &str,
 132        max_access_token_count: usize,
 133    ) -> Result<()>;
 134    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
 135
 136    #[cfg(any(test, feature = "seed-support"))]
 137    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
 138    #[cfg(any(test, feature = "seed-support"))]
 139    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
 140    #[cfg(any(test, feature = "seed-support"))]
 141    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
 142    #[cfg(any(test, feature = "seed-support"))]
 143    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
 144    #[cfg(any(test, feature = "seed-support"))]
 145
 146    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
 147    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
 148    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
 149        -> Result<bool>;
 150
 151    #[cfg(any(test, feature = "seed-support"))]
 152    async fn add_channel_member(
 153        &self,
 154        channel_id: ChannelId,
 155        user_id: UserId,
 156        is_admin: bool,
 157    ) -> Result<()>;
 158    async fn create_channel_message(
 159        &self,
 160        channel_id: ChannelId,
 161        sender_id: UserId,
 162        body: &str,
 163        timestamp: OffsetDateTime,
 164        nonce: u128,
 165    ) -> Result<MessageId>;
 166    async fn get_channel_messages(
 167        &self,
 168        channel_id: ChannelId,
 169        count: usize,
 170        before_id: Option<MessageId>,
 171    ) -> Result<Vec<ChannelMessage>>;
 172
 173    #[cfg(test)]
 174    async fn teardown(&self, url: &str);
 175
 176    #[cfg(test)]
 177    fn as_fake(&self) -> Option<&FakeDb>;
 178}
 179
 180#[cfg(any(test, debug_assertions))]
 181pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> =
 182    Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
 183
 184#[cfg(not(any(test, debug_assertions)))]
 185pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> = None;
 186
 187pub struct PostgresDb {
 188    pool: sqlx::PgPool,
 189}
 190
 191macro_rules! test_support {
 192    ($self:ident, { $($token:tt)* }) => {{
 193        let body = async {
 194            $($token)*
 195        };
 196
 197        if cfg!(test) {
 198            tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build().unwrap().block_on(body)
 199        } else {
 200            body.await
 201        }
 202    }};
 203}
 204
 205impl PostgresDb {
 206    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 207        let pool = DbOptions::new()
 208            .max_connections(max_connections)
 209            .connect(url)
 210            .await
 211            .context("failed to connect to postgres database")?;
 212        Ok(Self { pool })
 213    }
 214
 215    pub async fn migrate(
 216        &self,
 217        migrations_path: &Path,
 218        ignore_checksum_mismatch: bool,
 219    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 220        let migrations = MigrationSource::resolve(migrations_path)
 221            .await
 222            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 223
 224        let mut conn = self.pool.acquire().await?;
 225
 226        conn.ensure_migrations_table().await?;
 227        let applied_migrations: HashMap<_, _> = conn
 228            .list_applied_migrations()
 229            .await?
 230            .into_iter()
 231            .map(|m| (m.version, m))
 232            .collect();
 233
 234        let mut new_migrations = Vec::new();
 235        for migration in migrations {
 236            match applied_migrations.get(&migration.version) {
 237                Some(applied_migration) => {
 238                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 239                    {
 240                        Err(anyhow!(
 241                            "checksum mismatch for applied migration {}",
 242                            migration.description
 243                        ))?;
 244                    }
 245                }
 246                None => {
 247                    let elapsed = conn.apply(&migration).await?;
 248                    new_migrations.push((migration, elapsed));
 249                }
 250            }
 251        }
 252
 253        Ok(new_migrations)
 254    }
 255
 256    pub fn fuzzy_like_string(string: &str) -> String {
 257        let mut result = String::with_capacity(string.len() * 2 + 1);
 258        for c in string.chars() {
 259            if c.is_alphanumeric() {
 260                result.push('%');
 261                result.push(c);
 262            }
 263        }
 264        result.push('%');
 265        result
 266    }
 267}
 268
 269#[async_trait]
 270impl Db for PostgresDb {
 271    // users
 272
 273    async fn create_user(
 274        &self,
 275        email_address: &str,
 276        admin: bool,
 277        params: NewUserParams,
 278    ) -> Result<NewUserResult> {
 279        test_support!(self, {
 280            let query = "
 281                INSERT INTO users (email_address, github_login, github_user_id, admin)
 282                VALUES ($1, $2, $3, $4)
 283                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 284                RETURNING id, metrics_id::text
 285            ";
 286
 287            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 288                .bind(email_address)
 289                .bind(params.github_login)
 290                .bind(params.github_user_id)
 291                .bind(admin)
 292                .fetch_one(&self.pool)
 293                .await?;
 294            Ok(NewUserResult {
 295                user_id,
 296                metrics_id,
 297                signup_device_id: None,
 298                inviting_user_id: None,
 299            })
 300        })
 301    }
 302
 303    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 304        test_support!(self, {
 305            let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 306            Ok(sqlx::query_as(query)
 307                .bind(limit as i32)
 308                .bind((page * limit) as i32)
 309                .fetch_all(&self.pool)
 310                .await?)
 311        })
 312    }
 313
 314    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 315        test_support!(self, {
 316            let like_string = Self::fuzzy_like_string(name_query);
 317            let query = "
 318                SELECT users.*
 319                FROM users
 320                WHERE github_login ILIKE $1
 321                ORDER BY github_login <-> $2
 322                LIMIT $3
 323            ";
 324            Ok(sqlx::query_as(query)
 325                .bind(like_string)
 326                .bind(name_query)
 327                .bind(limit as i32)
 328                .fetch_all(&self.pool)
 329                .await?)
 330        })
 331    }
 332
 333    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 334        let users = self.get_users_by_ids(vec![id]).await?;
 335        Ok(users.into_iter().next())
 336    }
 337
 338    async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 339        test_support!(self, {
 340            let query = "
 341                SELECT metrics_id::text
 342                FROM users
 343                WHERE id = $1
 344            ";
 345            Ok(sqlx::query_scalar(query)
 346                .bind(id)
 347                .fetch_one(&self.pool)
 348                .await?)
 349        })
 350    }
 351
 352    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 353        test_support!(self, {
 354            let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 355            let query = "
 356                SELECT users.*
 357                FROM users
 358                WHERE users.id = ANY ($1)
 359            ";
 360            Ok(sqlx::query_as(query)
 361                .bind(&ids)
 362                .fetch_all(&self.pool)
 363                .await?)
 364        })
 365    }
 366
 367    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
 368        test_support!(self, {
 369            let query = format!(
 370                "
 371                SELECT users.*
 372                FROM users
 373                WHERE invite_count = 0
 374                AND inviter_id IS{} NULL
 375                ",
 376                if invited_by_another_user { " NOT" } else { "" }
 377            );
 378
 379            Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 380        })
 381    }
 382
 383    async fn get_user_by_github_account(
 384        &self,
 385        github_login: &str,
 386        github_user_id: Option<i32>,
 387    ) -> Result<Option<User>> {
 388        test_support!(self, {
 389            if let Some(github_user_id) = github_user_id {
 390                let mut user = sqlx::query_as::<_, User>(
 391                    "
 392                    UPDATE users
 393                    SET github_login = $1
 394                    WHERE github_user_id = $2
 395                    RETURNING *
 396                    ",
 397                )
 398                .bind(github_login)
 399                .bind(github_user_id)
 400                .fetch_optional(&self.pool)
 401                .await?;
 402
 403                if user.is_none() {
 404                    user = sqlx::query_as::<_, User>(
 405                        "
 406                        UPDATE users
 407                        SET github_user_id = $1
 408                        WHERE github_login = $2
 409                        RETURNING *
 410                        ",
 411                    )
 412                    .bind(github_user_id)
 413                    .bind(github_login)
 414                    .fetch_optional(&self.pool)
 415                    .await?;
 416                }
 417
 418                Ok(user)
 419            } else {
 420                let user = sqlx::query_as(
 421                    "
 422                    SELECT * FROM users
 423                    WHERE github_login = $1
 424                    LIMIT 1
 425                    ",
 426                )
 427                .bind(github_login)
 428                .fetch_optional(&self.pool)
 429                .await?;
 430                Ok(user)
 431            }
 432        })
 433    }
 434
 435    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 436        test_support!(self, {
 437            let query = "UPDATE users SET admin = $1 WHERE id = $2";
 438            Ok(sqlx::query(query)
 439                .bind(is_admin)
 440                .bind(id.0)
 441                .execute(&self.pool)
 442                .await
 443                .map(drop)?)
 444        })
 445    }
 446
 447    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 448        test_support!(self, {
 449            let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 450            Ok(sqlx::query(query)
 451                .bind(connected_once)
 452                .bind(id.0)
 453                .execute(&self.pool)
 454                .await
 455                .map(drop)?)
 456        })
 457    }
 458
 459    async fn destroy_user(&self, id: UserId) -> Result<()> {
 460        test_support!(self, {
 461            let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 462            sqlx::query(query)
 463                .bind(id.0)
 464                .execute(&self.pool)
 465                .await
 466                .map(drop)?;
 467            let query = "DELETE FROM users WHERE id = $1;";
 468            Ok(sqlx::query(query)
 469                .bind(id.0)
 470                .execute(&self.pool)
 471                .await
 472                .map(drop)?)
 473        })
 474    }
 475
 476    // signups
 477
 478    async fn create_signup(&self, signup: Signup) -> Result<()> {
 479        test_support!(self, {
 480            sqlx::query(
 481                "
 482                INSERT INTO signups
 483                (
 484                    email_address,
 485                    email_confirmation_code,
 486                    email_confirmation_sent,
 487                    platform_linux,
 488                    platform_mac,
 489                    platform_windows,
 490                    platform_unknown,
 491                    editor_features,
 492                    programming_languages,
 493                    device_id
 494                )
 495                VALUES
 496                    ($1, $2, 'f', $3, $4, $5, 'f', $6, $7, $8)
 497                RETURNING id
 498                ",
 499            )
 500            .bind(&signup.email_address)
 501            .bind(&random_email_confirmation_code())
 502            .bind(&signup.platform_linux)
 503            .bind(&signup.platform_mac)
 504            .bind(&signup.platform_windows)
 505            .bind(&signup.editor_features)
 506            .bind(&signup.programming_languages)
 507            .bind(&signup.device_id)
 508            .execute(&self.pool)
 509            .await?;
 510            Ok(())
 511        })
 512    }
 513
 514    async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
 515        test_support!(self, {
 516            Ok(sqlx::query_as(
 517                "
 518                SELECT
 519                    COUNT(*) as count,
 520                    COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
 521                    COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
 522                    COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
 523                    COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
 524                FROM (
 525                    SELECT *
 526                    FROM signups
 527                    WHERE
 528                        NOT email_confirmation_sent
 529                ) AS unsent
 530                ",
 531            )
 532            .fetch_one(&self.pool)
 533            .await?)
 534        })
 535    }
 536
 537    async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
 538        test_support!(self, {
 539            Ok(sqlx::query_as(
 540                "
 541                SELECT
 542                    email_address, email_confirmation_code
 543                FROM signups
 544                WHERE
 545                    NOT email_confirmation_sent AND
 546                    (platform_mac OR platform_unknown)
 547                LIMIT $1
 548                ",
 549            )
 550            .bind(count as i32)
 551            .fetch_all(&self.pool)
 552            .await?)
 553        })
 554    }
 555
 556    async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
 557        test_support!(self, {
 558            sqlx::query(
 559                "
 560                UPDATE signups
 561                SET email_confirmation_sent = 't'
 562                WHERE email_address = ANY ($1)
 563                ",
 564            )
 565            .bind(
 566                &invites
 567                    .iter()
 568                    .map(|s| s.email_address.as_str())
 569                    .collect::<Vec<_>>(),
 570            )
 571            .execute(&self.pool)
 572            .await?;
 573            Ok(())
 574        })
 575    }
 576
 577    async fn create_user_from_invite(
 578        &self,
 579        invite: &Invite,
 580        user: NewUserParams,
 581    ) -> Result<Option<NewUserResult>> {
 582        test_support!(self, {
 583            let mut tx = self.pool.begin().await?;
 584
 585            let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
 586                i32,
 587                Option<UserId>,
 588                Option<UserId>,
 589                Option<String>,
 590            ) = sqlx::query_as(
 591                "
 592                SELECT id, user_id, inviting_user_id, device_id
 593                FROM signups
 594                WHERE
 595                    email_address = $1 AND
 596                    email_confirmation_code = $2
 597                ",
 598            )
 599            .bind(&invite.email_address)
 600            .bind(&invite.email_confirmation_code)
 601            .fetch_optional(&mut tx)
 602            .await?
 603            .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
 604
 605            if existing_user_id.is_some() {
 606                return Ok(None);
 607            }
 608
 609            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
 610                "
 611                INSERT INTO users
 612                (email_address, github_login, github_user_id, admin, invite_count, invite_code)
 613                VALUES
 614                ($1, $2, $3, 'f', $4, $5)
 615                ON CONFLICT (github_login) DO UPDATE SET
 616                    email_address = excluded.email_address,
 617                    github_user_id = excluded.github_user_id,
 618                    admin = excluded.admin
 619                RETURNING id, metrics_id::text
 620                ",
 621            )
 622            .bind(&invite.email_address)
 623            .bind(&user.github_login)
 624            .bind(&user.github_user_id)
 625            .bind(&user.invite_count)
 626            .bind(random_invite_code())
 627            .fetch_one(&mut tx)
 628            .await?;
 629
 630            sqlx::query(
 631                "
 632                UPDATE signups
 633                SET user_id = $1
 634                WHERE id = $2
 635                ",
 636            )
 637            .bind(&user_id)
 638            .bind(&signup_id)
 639            .execute(&mut tx)
 640            .await?;
 641
 642            if let Some(inviting_user_id) = inviting_user_id {
 643                let id: Option<UserId> = sqlx::query_scalar(
 644                    "
 645                    UPDATE users
 646                    SET invite_count = invite_count - 1
 647                    WHERE id = $1 AND invite_count > 0
 648                    RETURNING id
 649                    ",
 650                )
 651                .bind(&inviting_user_id)
 652                .fetch_optional(&mut tx)
 653                .await?;
 654
 655                if id.is_none() {
 656                    Err(Error::Http(
 657                        StatusCode::UNAUTHORIZED,
 658                        "no invites remaining".to_string(),
 659                    ))?;
 660                }
 661
 662                sqlx::query(
 663                    "
 664                    INSERT INTO contacts
 665                        (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 666                    VALUES
 667                        ($1, $2, 't', 't', 't')
 668                    ON CONFLICT DO NOTHING
 669                    ",
 670                )
 671                .bind(inviting_user_id)
 672                .bind(user_id)
 673                .execute(&mut tx)
 674                .await?;
 675            }
 676
 677            tx.commit().await?;
 678            Ok(Some(NewUserResult {
 679                user_id,
 680                metrics_id,
 681                inviting_user_id,
 682                signup_device_id,
 683            }))
 684        })
 685    }
 686
 687    // invite codes
 688
 689    async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
 690        test_support!(self, {
 691            let mut tx = self.pool.begin().await?;
 692            if count > 0 {
 693                sqlx::query(
 694                    "
 695                    UPDATE users
 696                    SET invite_code = $1
 697                    WHERE id = $2 AND invite_code IS NULL
 698                ",
 699                )
 700                .bind(random_invite_code())
 701                .bind(id)
 702                .execute(&mut tx)
 703                .await?;
 704            }
 705
 706            sqlx::query(
 707                "
 708                UPDATE users
 709                SET invite_count = $1
 710                WHERE id = $2
 711                ",
 712            )
 713            .bind(count as i32)
 714            .bind(id)
 715            .execute(&mut tx)
 716            .await?;
 717            tx.commit().await?;
 718            Ok(())
 719        })
 720    }
 721
 722    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 723        test_support!(self, {
 724            let result: Option<(String, i32)> = sqlx::query_as(
 725                "
 726                    SELECT invite_code, invite_count
 727                    FROM users
 728                    WHERE id = $1 AND invite_code IS NOT NULL 
 729                ",
 730            )
 731            .bind(id)
 732            .fetch_optional(&self.pool)
 733            .await?;
 734            if let Some((code, count)) = result {
 735                Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 736            } else {
 737                Ok(None)
 738            }
 739        })
 740    }
 741
 742    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 743        test_support!(self, {
 744            sqlx::query_as(
 745                "
 746                    SELECT *
 747                    FROM users
 748                    WHERE invite_code = $1
 749                ",
 750            )
 751            .bind(code)
 752            .fetch_optional(&self.pool)
 753            .await?
 754            .ok_or_else(|| {
 755                Error::Http(
 756                    StatusCode::NOT_FOUND,
 757                    "that invite code does not exist".to_string(),
 758                )
 759            })
 760        })
 761    }
 762
 763    async fn create_invite_from_code(
 764        &self,
 765        code: &str,
 766        email_address: &str,
 767        device_id: Option<&str>,
 768    ) -> Result<Invite> {
 769        test_support!(self, {
 770            let mut tx = self.pool.begin().await?;
 771
 772            let existing_user: Option<UserId> = sqlx::query_scalar(
 773                "
 774                SELECT id
 775                FROM users
 776                WHERE email_address = $1
 777                ",
 778            )
 779            .bind(email_address)
 780            .fetch_optional(&mut tx)
 781            .await?;
 782            if existing_user.is_some() {
 783                Err(anyhow!("email address is already in use"))?;
 784            }
 785
 786            let row: Option<(UserId, i32)> = sqlx::query_as(
 787                "
 788                SELECT id, invite_count
 789                FROM users
 790                WHERE invite_code = $1
 791                ",
 792            )
 793            .bind(code)
 794            .fetch_optional(&mut tx)
 795            .await?;
 796
 797            let (inviter_id, invite_count) = match row {
 798                Some(row) => row,
 799                None => Err(Error::Http(
 800                    StatusCode::NOT_FOUND,
 801                    "invite code not found".to_string(),
 802                ))?,
 803            };
 804
 805            if invite_count == 0 {
 806                Err(Error::Http(
 807                    StatusCode::UNAUTHORIZED,
 808                    "no invites remaining".to_string(),
 809                ))?;
 810            }
 811
 812            let email_confirmation_code: String = sqlx::query_scalar(
 813                "
 814                INSERT INTO signups
 815                (
 816                    email_address,
 817                    email_confirmation_code,
 818                    email_confirmation_sent,
 819                    inviting_user_id,
 820                    platform_linux,
 821                    platform_mac,
 822                    platform_windows,
 823                    platform_unknown,
 824                    device_id
 825                )
 826                VALUES
 827                    ($1, $2, 'f', $3, 'f', 'f', 'f', 't', $4)
 828                ON CONFLICT (email_address)
 829                DO UPDATE SET
 830                    inviting_user_id = excluded.inviting_user_id
 831                RETURNING email_confirmation_code
 832                ",
 833            )
 834            .bind(&email_address)
 835            .bind(&random_email_confirmation_code())
 836            .bind(&inviter_id)
 837            .bind(&device_id)
 838            .fetch_one(&mut tx)
 839            .await?;
 840
 841            tx.commit().await?;
 842
 843            Ok(Invite {
 844                email_address: email_address.into(),
 845                email_confirmation_code,
 846            })
 847        })
 848    }
 849
 850    // projects
 851
 852    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
 853        test_support!(self, {
 854            Ok(sqlx::query_scalar(
 855                "
 856                INSERT INTO projects(host_user_id)
 857                VALUES ($1)
 858                RETURNING id
 859                ",
 860            )
 861            .bind(host_user_id)
 862            .fetch_one(&self.pool)
 863            .await
 864            .map(ProjectId)?)
 865        })
 866    }
 867
 868    async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
 869        test_support!(self, {
 870            sqlx::query(
 871                "
 872                UPDATE projects
 873                SET unregistered = 't'
 874                WHERE id = $1
 875                ",
 876            )
 877            .bind(project_id)
 878            .execute(&self.pool)
 879            .await?;
 880            Ok(())
 881        })
 882    }
 883
 884    async fn update_worktree_extensions(
 885        &self,
 886        project_id: ProjectId,
 887        worktree_id: u64,
 888        extensions: HashMap<String, u32>,
 889    ) -> Result<()> {
 890        test_support!(self, {
 891            if extensions.is_empty() {
 892                return Ok(());
 893            }
 894
 895            let mut query = QueryBuilder::new(
 896                "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
 897            );
 898            query.push_values(extensions, |mut query, (extension, count)| {
 899                query
 900                    .push_bind(project_id)
 901                    .push_bind(worktree_id as i32)
 902                    .push_bind(extension)
 903                    .push_bind(count as i32);
 904            });
 905            query.push(
 906                "
 907                ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
 908                count = excluded.count
 909                ",
 910            );
 911            query.build().execute(&self.pool).await?;
 912
 913            Ok(())
 914        })
 915    }
 916
 917    async fn get_project_extensions(
 918        &self,
 919        project_id: ProjectId,
 920    ) -> Result<HashMap<u64, HashMap<String, usize>>> {
 921        test_support!(self, {
 922            #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
 923            struct WorktreeExtension {
 924                worktree_id: i32,
 925                extension: String,
 926                count: i32,
 927            }
 928
 929            let query = "
 930                SELECT worktree_id, extension, count
 931                FROM worktree_extensions
 932                WHERE project_id = $1
 933            ";
 934            let counts = sqlx::query_as::<_, WorktreeExtension>(query)
 935                .bind(&project_id)
 936                .fetch_all(&self.pool)
 937                .await?;
 938
 939            let mut extension_counts = HashMap::default();
 940            for count in counts {
 941                extension_counts
 942                    .entry(count.worktree_id as u64)
 943                    .or_insert_with(HashMap::default)
 944                    .insert(count.extension, count.count as usize);
 945            }
 946            Ok(extension_counts)
 947        })
 948    }
 949
 950    async fn record_user_activity(
 951        &self,
 952        time_period: Range<OffsetDateTime>,
 953        projects: &[(UserId, ProjectId)],
 954    ) -> Result<()> {
 955        test_support!(self, {
 956            let query = "
 957                INSERT INTO project_activity_periods
 958                (ended_at, duration_millis, user_id, project_id)
 959                VALUES
 960                ($1, $2, $3, $4);
 961            ";
 962
 963            let mut tx = self.pool.begin().await?;
 964            let duration_millis =
 965                ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
 966            for (user_id, project_id) in projects {
 967                sqlx::query(query)
 968                    .bind(time_period.end)
 969                    .bind(duration_millis)
 970                    .bind(user_id)
 971                    .bind(project_id)
 972                    .execute(&mut tx)
 973                    .await?;
 974            }
 975            tx.commit().await?;
 976
 977            Ok(())
 978        })
 979    }
 980
 981    async fn get_active_user_count(
 982        &self,
 983        time_period: Range<OffsetDateTime>,
 984        min_duration: Duration,
 985        only_collaborative: bool,
 986    ) -> Result<usize> {
 987        test_support!(self, {
 988            let mut with_clause = String::new();
 989            with_clause.push_str("WITH\n");
 990            with_clause.push_str(
 991                "
 992                project_durations AS (
 993                    SELECT user_id, project_id, SUM(duration_millis) AS project_duration
 994                    FROM project_activity_periods
 995                    WHERE $1 < ended_at AND ended_at <= $2
 996                    GROUP BY user_id, project_id
 997                ),
 998                ",
 999            );
1000            with_clause.push_str(
1001                "
1002                project_collaborators as (
1003                    SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
1004                    FROM project_durations
1005                    GROUP BY project_id
1006                ),
1007                ",
1008            );
1009
1010            if only_collaborative {
1011                with_clause.push_str(
1012                    "
1013                    user_durations AS (
1014                        SELECT user_id, SUM(project_duration) as total_duration
1015                        FROM project_durations, project_collaborators
1016                        WHERE
1017                            project_durations.project_id = project_collaborators.project_id AND
1018                            max_collaborators > 1
1019                        GROUP BY user_id
1020                        ORDER BY total_duration DESC
1021                        LIMIT $3
1022                    )
1023                    ",
1024                );
1025            } else {
1026                with_clause.push_str(
1027                    "
1028                    user_durations AS (
1029                        SELECT user_id, SUM(project_duration) as total_duration
1030                        FROM project_durations
1031                        GROUP BY user_id
1032                        ORDER BY total_duration DESC
1033                        LIMIT $3
1034                    )
1035                    ",
1036                );
1037            }
1038
1039            let query = format!(
1040                "
1041                {with_clause}
1042                SELECT count(user_durations.user_id)
1043                FROM user_durations
1044                WHERE user_durations.total_duration >= $3
1045                "
1046            );
1047
1048            let count: i64 = sqlx::query_scalar(&query)
1049                .bind(time_period.start)
1050                .bind(time_period.end)
1051                .bind(min_duration.as_millis() as i64)
1052                .fetch_one(&self.pool)
1053                .await?;
1054            Ok(count as usize)
1055        })
1056    }
1057
1058    async fn get_top_users_activity_summary(
1059        &self,
1060        time_period: Range<OffsetDateTime>,
1061        max_user_count: usize,
1062    ) -> Result<Vec<UserActivitySummary>> {
1063        test_support!(self, {
1064            let query = "
1065                WITH
1066                    project_durations AS (
1067                        SELECT user_id, project_id, SUM(duration_millis) AS project_duration
1068                        FROM project_activity_periods
1069                        WHERE $1 < ended_at AND ended_at <= $2
1070                        GROUP BY user_id, project_id
1071                    ),
1072                    user_durations AS (
1073                        SELECT user_id, SUM(project_duration) as total_duration
1074                        FROM project_durations
1075                        GROUP BY user_id
1076                        ORDER BY total_duration DESC
1077                        LIMIT $3
1078                    ),
1079                    project_collaborators as (
1080                        SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
1081                        FROM project_durations
1082                        GROUP BY project_id
1083                    )
1084                SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
1085                FROM user_durations, project_durations, project_collaborators, users
1086                WHERE
1087                    user_durations.user_id = project_durations.user_id AND
1088                    user_durations.user_id = users.id AND
1089                    project_durations.project_id = project_collaborators.project_id
1090                ORDER BY total_duration DESC, user_id ASC, project_id ASC
1091            ";
1092
1093            let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
1094                .bind(time_period.start)
1095                .bind(time_period.end)
1096                .bind(max_user_count as i32)
1097                .fetch(&self.pool);
1098
1099            let mut result = Vec::<UserActivitySummary>::new();
1100            while let Some(row) = rows.next().await {
1101                let (user_id, github_login, project_id, duration_millis, project_collaborators) =
1102                    row?;
1103                let project_id = project_id;
1104                let duration = Duration::from_millis(duration_millis as u64);
1105                let project_activity = ProjectActivitySummary {
1106                    id: project_id,
1107                    duration,
1108                    max_collaborators: project_collaborators as usize,
1109                };
1110                if let Some(last_summary) = result.last_mut() {
1111                    if last_summary.id == user_id {
1112                        last_summary.project_activity.push(project_activity);
1113                        continue;
1114                    }
1115                }
1116                result.push(UserActivitySummary {
1117                    id: user_id,
1118                    project_activity: vec![project_activity],
1119                    github_login,
1120                });
1121            }
1122
1123            Ok(result)
1124        })
1125    }
1126
1127    async fn get_user_activity_timeline(
1128        &self,
1129        time_period: Range<OffsetDateTime>,
1130        user_id: UserId,
1131    ) -> Result<Vec<UserActivityPeriod>> {
1132        test_support!(self, {
1133            const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
1134
1135            let query = "
1136                SELECT
1137                    project_activity_periods.ended_at,
1138                    project_activity_periods.duration_millis,
1139                    project_activity_periods.project_id,
1140                    worktree_extensions.extension,
1141                    worktree_extensions.count
1142                FROM project_activity_periods
1143                LEFT OUTER JOIN
1144                    worktree_extensions
1145                ON
1146                    project_activity_periods.project_id = worktree_extensions.project_id
1147                WHERE
1148                    project_activity_periods.user_id = $1 AND
1149                    $2 < project_activity_periods.ended_at AND
1150                    project_activity_periods.ended_at <= $3
1151                ORDER BY project_activity_periods.id ASC
1152            ";
1153
1154            let mut rows = sqlx::query_as::<
1155                _,
1156                (
1157                    PrimitiveDateTime,
1158                    i32,
1159                    ProjectId,
1160                    Option<String>,
1161                    Option<i32>,
1162                ),
1163            >(query)
1164            .bind(user_id)
1165            .bind(time_period.start)
1166            .bind(time_period.end)
1167            .fetch(&self.pool);
1168
1169            let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
1170            while let Some(row) = rows.next().await {
1171                let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
1172                let ended_at = ended_at.assume_utc();
1173                let duration = Duration::from_millis(duration_millis as u64);
1174                let started_at = ended_at - duration;
1175                let project_time_periods = time_periods.entry(project_id).or_default();
1176
1177                if let Some(prev_duration) = project_time_periods.last_mut() {
1178                    if started_at <= prev_duration.end + COALESCE_THRESHOLD
1179                        && ended_at >= prev_duration.start
1180                    {
1181                        prev_duration.end = cmp::max(prev_duration.end, ended_at);
1182                    } else {
1183                        project_time_periods.push(UserActivityPeriod {
1184                            project_id,
1185                            start: started_at,
1186                            end: ended_at,
1187                            extensions: Default::default(),
1188                        });
1189                    }
1190                } else {
1191                    project_time_periods.push(UserActivityPeriod {
1192                        project_id,
1193                        start: started_at,
1194                        end: ended_at,
1195                        extensions: Default::default(),
1196                    });
1197                }
1198
1199                if let Some((extension, extension_count)) = extension.zip(extension_count) {
1200                    project_time_periods
1201                        .last_mut()
1202                        .unwrap()
1203                        .extensions
1204                        .insert(extension, extension_count as usize);
1205                }
1206            }
1207
1208            let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
1209            durations.sort_unstable_by_key(|duration| duration.start);
1210            Ok(durations)
1211        })
1212    }
1213
1214    // contacts
1215
1216    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
1217        test_support!(self, {
1218            let query = "
1219                SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
1220                FROM contacts
1221                WHERE user_id_a = $1 OR user_id_b = $1;
1222            ";
1223
1224            let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
1225                .bind(user_id)
1226                .fetch(&self.pool);
1227
1228            let mut contacts = Vec::new();
1229            while let Some(row) = rows.next().await {
1230                let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
1231
1232                if user_id_a == user_id {
1233                    if accepted {
1234                        contacts.push(Contact::Accepted {
1235                            user_id: user_id_b,
1236                            should_notify: should_notify && a_to_b,
1237                        });
1238                    } else if a_to_b {
1239                        contacts.push(Contact::Outgoing { user_id: user_id_b })
1240                    } else {
1241                        contacts.push(Contact::Incoming {
1242                            user_id: user_id_b,
1243                            should_notify,
1244                        });
1245                    }
1246                } else if accepted {
1247                    contacts.push(Contact::Accepted {
1248                        user_id: user_id_a,
1249                        should_notify: should_notify && !a_to_b,
1250                    });
1251                } else if a_to_b {
1252                    contacts.push(Contact::Incoming {
1253                        user_id: user_id_a,
1254                        should_notify,
1255                    });
1256                } else {
1257                    contacts.push(Contact::Outgoing { user_id: user_id_a });
1258                }
1259            }
1260
1261            contacts.sort_unstable_by_key(|contact| contact.user_id());
1262
1263            Ok(contacts)
1264        })
1265    }
1266
1267    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1268        test_support!(self, {
1269            let (id_a, id_b) = if user_id_1 < user_id_2 {
1270                (user_id_1, user_id_2)
1271            } else {
1272                (user_id_2, user_id_1)
1273            };
1274
1275            let query = "
1276                SELECT 1 FROM contacts
1277                WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
1278                LIMIT 1
1279            ";
1280            Ok(sqlx::query_scalar::<_, i32>(query)
1281                .bind(id_a.0)
1282                .bind(id_b.0)
1283                .fetch_optional(&self.pool)
1284                .await?
1285                .is_some())
1286        })
1287    }
1288
1289    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
1290        test_support!(self, {
1291            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
1292                (sender_id, receiver_id, true)
1293            } else {
1294                (receiver_id, sender_id, false)
1295            };
1296            let query = "
1297                INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1298                VALUES ($1, $2, $3, 'f', 't')
1299                ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1300                SET
1301                    accepted = 't',
1302                    should_notify = 'f'
1303                WHERE
1304                    NOT contacts.accepted AND
1305                    ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1306                    (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1307            ";
1308            let result = sqlx::query(query)
1309                .bind(id_a.0)
1310                .bind(id_b.0)
1311                .bind(a_to_b)
1312                .execute(&self.pool)
1313                .await?;
1314
1315            if result.rows_affected() == 1 {
1316                Ok(())
1317            } else {
1318                Err(anyhow!("contact already requested"))?
1319            }
1320        })
1321    }
1322
1323    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1324        test_support!(self, {
1325            let (id_a, id_b) = if responder_id < requester_id {
1326                (responder_id, requester_id)
1327            } else {
1328                (requester_id, responder_id)
1329            };
1330            let query = "
1331                DELETE FROM contacts
1332                WHERE user_id_a = $1 AND user_id_b = $2;
1333            ";
1334            let result = sqlx::query(query)
1335                .bind(id_a.0)
1336                .bind(id_b.0)
1337                .execute(&self.pool)
1338                .await?;
1339
1340            if result.rows_affected() == 1 {
1341                Ok(())
1342            } else {
1343                Err(anyhow!("no such contact"))?
1344            }
1345        })
1346    }
1347
1348    async fn dismiss_contact_notification(
1349        &self,
1350        user_id: UserId,
1351        contact_user_id: UserId,
1352    ) -> Result<()> {
1353        test_support!(self, {
1354            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1355                (user_id, contact_user_id, true)
1356            } else {
1357                (contact_user_id, user_id, false)
1358            };
1359
1360            let query = "
1361                UPDATE contacts
1362                SET should_notify = 'f'
1363                WHERE
1364                    user_id_a = $1 AND user_id_b = $2 AND
1365                    (
1366                        (a_to_b = $3 AND accepted) OR
1367                        (a_to_b != $3 AND NOT accepted)
1368                    );
1369            ";
1370
1371            let result = sqlx::query(query)
1372                .bind(id_a.0)
1373                .bind(id_b.0)
1374                .bind(a_to_b)
1375                .execute(&self.pool)
1376                .await?;
1377
1378            if result.rows_affected() == 0 {
1379                Err(anyhow!("no such contact request"))?;
1380            }
1381
1382            Ok(())
1383        })
1384    }
1385
1386    async fn respond_to_contact_request(
1387        &self,
1388        responder_id: UserId,
1389        requester_id: UserId,
1390        accept: bool,
1391    ) -> Result<()> {
1392        test_support!(self, {
1393            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1394                (responder_id, requester_id, false)
1395            } else {
1396                (requester_id, responder_id, true)
1397            };
1398            let result = if accept {
1399                let query = "
1400                    UPDATE contacts
1401                    SET accepted = 't', should_notify = 't'
1402                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1403                ";
1404                sqlx::query(query)
1405                    .bind(id_a.0)
1406                    .bind(id_b.0)
1407                    .bind(a_to_b)
1408                    .execute(&self.pool)
1409                    .await?
1410            } else {
1411                let query = "
1412                    DELETE FROM contacts
1413                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1414                ";
1415                sqlx::query(query)
1416                    .bind(id_a.0)
1417                    .bind(id_b.0)
1418                    .bind(a_to_b)
1419                    .execute(&self.pool)
1420                    .await?
1421            };
1422            if result.rows_affected() == 1 {
1423                Ok(())
1424            } else {
1425                Err(anyhow!("no such contact request"))?
1426            }
1427        })
1428    }
1429
1430    // access tokens
1431
1432    async fn create_access_token_hash(
1433        &self,
1434        user_id: UserId,
1435        access_token_hash: &str,
1436        max_access_token_count: usize,
1437    ) -> Result<()> {
1438        test_support!(self, {
1439            let insert_query = "
1440                INSERT INTO access_tokens (user_id, hash)
1441                VALUES ($1, $2);
1442            ";
1443            let cleanup_query = "
1444                DELETE FROM access_tokens
1445                WHERE id IN (
1446                    SELECT id from access_tokens
1447                    WHERE user_id = $1
1448                    ORDER BY id DESC
1449                    OFFSET $3
1450                )
1451            ";
1452
1453            let mut tx = self.pool.begin().await?;
1454            sqlx::query(insert_query)
1455                .bind(user_id.0)
1456                .bind(access_token_hash)
1457                .execute(&mut tx)
1458                .await?;
1459            sqlx::query(cleanup_query)
1460                .bind(user_id.0)
1461                .bind(access_token_hash)
1462                .bind(max_access_token_count as i32)
1463                .execute(&mut tx)
1464                .await?;
1465            Ok(tx.commit().await?)
1466        })
1467    }
1468
1469    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1470        test_support!(self, {
1471            let query = "
1472                SELECT hash
1473                FROM access_tokens
1474                WHERE user_id = $1
1475                ORDER BY id DESC
1476            ";
1477            Ok(sqlx::query_scalar(query)
1478                .bind(user_id.0)
1479                .fetch_all(&self.pool)
1480                .await?)
1481        })
1482    }
1483
1484    // orgs
1485
1486    #[allow(unused)] // Help rust-analyzer
1487    #[cfg(any(test, feature = "seed-support"))]
1488    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1489        test_support!(self, {
1490            let query = "
1491                SELECT *
1492                FROM orgs
1493                WHERE slug = $1
1494            ";
1495            Ok(sqlx::query_as(query)
1496                .bind(slug)
1497                .fetch_optional(&self.pool)
1498                .await?)
1499        })
1500    }
1501
1502    #[cfg(any(test, feature = "seed-support"))]
1503    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1504        test_support!(self, {
1505            let query = "
1506                INSERT INTO orgs (name, slug)
1507                VALUES ($1, $2)
1508                RETURNING id
1509            ";
1510            Ok(sqlx::query_scalar(query)
1511                .bind(name)
1512                .bind(slug)
1513                .fetch_one(&self.pool)
1514                .await
1515                .map(OrgId)?)
1516        })
1517    }
1518
1519    #[cfg(any(test, feature = "seed-support"))]
1520    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1521        test_support!(self, {
1522            let query = "
1523                INSERT INTO org_memberships (org_id, user_id, admin)
1524                VALUES ($1, $2, $3)
1525                ON CONFLICT DO NOTHING
1526            ";
1527            Ok(sqlx::query(query)
1528                .bind(org_id.0)
1529                .bind(user_id.0)
1530                .bind(is_admin)
1531                .execute(&self.pool)
1532                .await
1533                .map(drop)?)
1534        })
1535    }
1536
1537    // channels
1538
1539    #[cfg(any(test, feature = "seed-support"))]
1540    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1541        test_support!(self, {
1542            let query = "
1543                INSERT INTO channels (owner_id, owner_is_user, name)
1544                VALUES ($1, false, $2)
1545                RETURNING id
1546            ";
1547            Ok(sqlx::query_scalar(query)
1548                .bind(org_id.0)
1549                .bind(name)
1550                .fetch_one(&self.pool)
1551                .await
1552                .map(ChannelId)?)
1553        })
1554    }
1555
1556    #[allow(unused)] // Help rust-analyzer
1557    #[cfg(any(test, feature = "seed-support"))]
1558    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1559        test_support!(self, {
1560            let query = "
1561                SELECT *
1562                FROM channels
1563                WHERE
1564                    channels.owner_is_user = false AND
1565                    channels.owner_id = $1
1566            ";
1567            Ok(sqlx::query_as(query)
1568                .bind(org_id.0)
1569                .fetch_all(&self.pool)
1570                .await?)
1571        })
1572    }
1573
1574    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1575        test_support!(self, {
1576            let query = "
1577                SELECT
1578                    channels.*
1579                FROM
1580                    channel_memberships, channels
1581                WHERE
1582                    channel_memberships.user_id = $1 AND
1583                    channel_memberships.channel_id = channels.id
1584            ";
1585            Ok(sqlx::query_as(query)
1586                .bind(user_id.0)
1587                .fetch_all(&self.pool)
1588                .await?)
1589        })
1590    }
1591
1592    async fn can_user_access_channel(
1593        &self,
1594        user_id: UserId,
1595        channel_id: ChannelId,
1596    ) -> Result<bool> {
1597        test_support!(self, {
1598            let query = "
1599                SELECT id
1600                FROM channel_memberships
1601                WHERE user_id = $1 AND channel_id = $2
1602                LIMIT 1
1603            ";
1604            Ok(sqlx::query_scalar::<_, i32>(query)
1605                .bind(user_id.0)
1606                .bind(channel_id.0)
1607                .fetch_optional(&self.pool)
1608                .await
1609                .map(|e| e.is_some())?)
1610        })
1611    }
1612
1613    #[cfg(any(test, feature = "seed-support"))]
1614    async fn add_channel_member(
1615        &self,
1616        channel_id: ChannelId,
1617        user_id: UserId,
1618        is_admin: bool,
1619    ) -> Result<()> {
1620        test_support!(self, {
1621            let query = "
1622                INSERT INTO channel_memberships (channel_id, user_id, admin)
1623                VALUES ($1, $2, $3)
1624                ON CONFLICT DO NOTHING
1625            ";
1626            Ok(sqlx::query(query)
1627                .bind(channel_id.0)
1628                .bind(user_id.0)
1629                .bind(is_admin)
1630                .execute(&self.pool)
1631                .await
1632                .map(drop)?)
1633        })
1634    }
1635
1636    // messages
1637
1638    async fn create_channel_message(
1639        &self,
1640        channel_id: ChannelId,
1641        sender_id: UserId,
1642        body: &str,
1643        timestamp: OffsetDateTime,
1644        nonce: u128,
1645    ) -> Result<MessageId> {
1646        test_support!(self, {
1647            let query = "
1648                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1649                VALUES ($1, $2, $3, $4, $5)
1650                ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1651                RETURNING id
1652            ";
1653            Ok(sqlx::query_scalar(query)
1654                .bind(channel_id.0)
1655                .bind(sender_id.0)
1656                .bind(body)
1657                .bind(timestamp)
1658                .bind(Uuid::from_u128(nonce))
1659                .fetch_one(&self.pool)
1660                .await
1661                .map(MessageId)?)
1662        })
1663    }
1664
1665    async fn get_channel_messages(
1666        &self,
1667        channel_id: ChannelId,
1668        count: usize,
1669        before_id: Option<MessageId>,
1670    ) -> Result<Vec<ChannelMessage>> {
1671        test_support!(self, {
1672            let query = r#"
1673                SELECT * FROM (
1674                    SELECT
1675                        id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1676                    FROM
1677                        channel_messages
1678                    WHERE
1679                        channel_id = $1 AND
1680                        id < $2
1681                    ORDER BY id DESC
1682                    LIMIT $3
1683                ) as recent_messages
1684                ORDER BY id ASC
1685            "#;
1686            Ok(sqlx::query_as(query)
1687                .bind(channel_id.0)
1688                .bind(before_id.unwrap_or(MessageId::MAX))
1689                .bind(count as i64)
1690                .fetch_all(&self.pool)
1691                .await?)
1692        })
1693    }
1694
1695    #[cfg(test)]
1696    async fn teardown(&self, url: &str) {
1697        let start = std::time::Instant::now();
1698        eprintln!("tearing down database...");
1699        test_support!(self, {
1700            use util::ResultExt;
1701
1702            let query = "
1703                SELECT pg_terminate_backend(pg_stat_activity.pid)
1704                FROM pg_stat_activity
1705                WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1706            ";
1707            sqlx::query(query).execute(&self.pool).await.log_err();
1708            self.pool.close().await;
1709            <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1710                .await
1711                .log_err();
1712            eprintln!("tore down database: {:?}", start.elapsed());
1713        })
1714    }
1715
    /// Test-only downcast hook. Always returns `None` because this
    /// implementation is not a `FakeDb`.
    #[cfg(test)]
    fn as_fake(&self) -> Option<&FakeDb> {
        None
    }
1720}
1721
/// Declares the given identifier as a `Copy` newtype over an `i32` database id.
///
/// The generated type is `#[sqlx(transparent)]` and `#[serde(transparent)]`,
/// so it encodes and serializes like the inner `i32`. It compares and orders
/// like the inner value and displays as the bare number. `from_proto` and
/// `to_proto` convert to and from `u64` with plain `as` casts, so values
/// outside the `i32` range are truncated or wrapped.
macro_rules! id_type {
    ($id:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $id(pub i32);

        impl $id {
            /// The largest representable id (`i32::MAX`).
            #[allow(unused)]
            pub const MAX: Self = $id(i32::MAX);

            /// Builds an id from a `u64` via a truncating `as` cast.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $id(value as i32)
            }

            /// Widens the id to `u64` via an `as` cast (negative ids wrap).
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $id {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1764
// `UserId`: typed wrapper around a user's `i32` id.
id_type!(UserId);
/// A user account row.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    /// GitHub login (username).
    pub github_login: String,
    /// Numeric GitHub user id, if one has been recorded for this user.
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    /// Admin flag (see `Db::set_user_is_admin`).
    pub admin: bool,
    /// The user's invite code, if one exists.
    pub invite_code: Option<String>,
    /// Invite count (see `Db::set_invite_count_for_user`).
    pub invite_count: i32,
    /// See `Db::set_user_connected_once`.
    pub connected_once: bool,
}
1777
// `ProjectId`: typed wrapper around a project's `i32` id.
id_type!(ProjectId);
/// A project registered by a host user.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// The user who registered the project.
    pub host_user_id: UserId,
    /// Marks the project as unregistered. The fake `unregister_project` sets this flag
    /// rather than removing the entry.
    pub unregistered: bool,
}
1785
/// Per-user activity summary, as returned by `Db::get_top_users_activity_summary`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// One entry per project the user was active in.
    pub project_activity: Vec<ProjectActivitySummary>,
}
1792
/// Activity within a single project, nested in [`UserActivitySummary`].
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ProjectActivitySummary {
    pub id: ProjectId,
    /// Presumably the total active time in the project during the queried period — TODO confirm.
    pub duration: Duration,
    /// Presumably the peak number of simultaneous collaborators — TODO confirm.
    pub max_collaborators: usize,
}
1799
/// One period of activity in a project, as returned by `Db::get_user_activity_timeline`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    pub project_id: ProjectId,
    /// Serialized as an ISO 8601 timestamp.
    #[serde(with = "time::serde::iso8601")]
    pub start: OffsetDateTime,
    /// Serialized as an ISO 8601 timestamp.
    #[serde(with = "time::serde::iso8601")]
    pub end: OffsetDateTime,
    /// Presumably file extension -> file count for the project's worktrees — TODO confirm.
    pub extensions: HashMap<String, usize>,
}
1809
// `OrgId`: typed wrapper around an organization's `i32` id.
id_type!(OrgId);
/// An organization, identified externally by its unique `slug`.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    pub slug: String,
}
1817
// `ChannelId`: typed wrapper around a channel's `i32` id.
id_type!(ChannelId);
/// A chat channel owned either by an org or by a user.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Raw id of the owner. It is an `OrgId` when `owner_is_user` is false, and
    /// presumably a `UserId` otherwise.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1826
// `MessageId`: typed wrapper around a channel message's `i32` id.
id_type!(MessageId);
/// A message sent to a channel.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// The Postgres query reads this column `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied nonce used to deduplicate resent messages.
    pub nonce: Uuid,
}
1837
/// One entry in a user's contact list, from that user's point of view.
///
/// In every variant, `user_id` is the *other* user.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The contact request has been accepted.
    Accepted {
        user_id: UserId,
        /// Whether a notification is still pending (cleared by `dismiss_contact_notification`).
        should_notify: bool,
    },
    /// This user sent a request to `user_id` that has not been accepted yet.
    Outgoing {
        user_id: UserId,
    },
    /// `user_id` sent this user a request that has not been accepted yet.
    Incoming {
        user_id: UserId,
        /// Whether a notification is still pending (cleared by `dismiss_contact_notification`).
        should_notify: bool,
    },
}
1852
1853impl Contact {
1854    pub fn user_id(&self) -> UserId {
1855        match self {
1856            Contact::Accepted { user_id, .. } => *user_id,
1857            Contact::Outgoing { user_id } => *user_id,
1858            Contact::Incoming { user_id, .. } => *user_id,
1859        }
1860    }
1861}
1862
/// A contact request received from `requester_id` (presumably still pending — TODO confirm).
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    /// Whether the recipient should still be notified about this request.
    pub should_notify: bool,
}
1868
/// A waitlist signup submission, consumed by `Db::create_signup`.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    // Platform flags selected by the person signing up.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    /// Optional client device identifier.
    pub device_id: Option<String>,
}
1879
/// Aggregate waitlist counts returned by `Db::get_waitlist_summary`.
///
/// Every field is `#[sqlx(default)]`, so a row missing a column decodes that count as 0.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
1893
/// An invite addressed to `email_address`.
///
/// The code is presumably produced by `random_email_confirmation_code` — TODO confirm.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
1899
/// GitHub identity and initial settings for a user being created.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    /// Requested invite count. The fake `create_user` ignores it and stores 0.
    pub invite_count: i32,
}
1906
/// Outcome of creating (or finding) a user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    /// Metrics identifier for the user (the fake returns a fixed placeholder).
    pub metrics_id: String,
    /// The user whose invite led to this account, if any.
    pub inviting_user_id: Option<UserId>,
    /// Device id recorded with the originating signup, if any.
    pub signup_device_id: Option<String>,
}
1914
1915fn random_invite_code() -> String {
1916    nanoid::nanoid!(16)
1917}
1918
1919fn random_email_confirmation_code() -> String {
1920    nanoid::nanoid!(64)
1921}
1922
1923#[cfg(test)]
1924pub use test::*;
1925
#[cfg(test)]
mod test {
    use super::*;
    use anyhow::anyhow;
    use collections::BTreeMap;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::{migrate::MigrateDatabase, Postgres};
    use std::sync::Arc;
    use util::post_inc;

    /// In-memory implementation of [`Db`] for tests.
    ///
    /// Each table is a `Mutex`-guarded map keyed by its id (or composite key), and new ids
    /// come from the `next_*` counters. Most async methods first await
    /// `Background::simulate_random_delay`, presumably to mimic a database round-trip.
    /// Methods the fake does not support are `unimplemented!()` and panic if called.
    pub struct FakeDb {
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        /// `(project, worktree id, extension)` -> count reported by `update_worktree_extensions`.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        /// Org membership; the value is the member's `is_admin` flag.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        /// Channel membership; the value is the member's `is_admin` flag.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        /// Directed contact entries. `send_contact_request` keeps at most one entry per pair.
        pub contacts: Mutex<Vec<FakeContact>>,
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }

    /// A contact relationship stored by [`FakeDb`]: a request from `requester_id` to
    /// `responder_id`, which becomes mutual once `accepted` is set.
    #[derive(Debug)]
    pub struct FakeContact {
        pub requester_id: UserId,
        pub responder_id: UserId,
        pub accepted: bool,
        /// Whether a notification about this entry is still pending.
        pub should_notify: bool,
    }

    impl FakeDb {
        /// Creates an empty fake database whose simulated delays run on `background`.
        pub fn new(background: Arc<Background>) -> Self {
            Self {
                background,
                users: Default::default(),
                // NOTE(review): user ids start at 0 while every other id starts at 1;
                // kept as-is because tests may depend on the exact ids assigned.
                next_user_id: Mutex::new(0),
                projects: Default::default(),
                worktree_extensions: Default::default(),
                next_project_id: Mutex::new(1),
                orgs: Default::default(),
                next_org_id: Mutex::new(1),
                org_memberships: Default::default(),
                channels: Default::default(),
                next_channel_id: Mutex::new(1),
                channel_memberships: Default::default(),
                channel_messages: Default::default(),
                next_channel_message_id: Mutex::new(1),
                contacts: Default::default(),
            }
        }
    }

    #[async_trait]
    impl Db for FakeDb {
        /// Returns the existing user with `params.github_login` if there is one; otherwise
        /// inserts a new user. `params.invite_count` is not stored (new users get 0), and
        /// the metrics id is a fixed placeholder.
        async fn create_user(
            &self,
            email_address: &str,
            admin: bool,
            params: NewUserParams,
        ) -> Result<NewUserResult> {
            self.background.simulate_random_delay().await;

            let mut users = self.users.lock();
            let user_id = if let Some(user) = users
                .values()
                .find(|user| user.github_login == params.github_login)
            {
                user.id
            } else {
                let id = post_inc(&mut *self.next_user_id.lock());
                let user_id = UserId(id);
                users.insert(
                    user_id,
                    User {
                        id: user_id,
                        github_login: params.github_login,
                        github_user_id: Some(params.github_user_id),
                        email_address: Some(email_address.to_string()),
                        admin,
                        invite_code: None,
                        invite_count: 0,
                        connected_once: false,
                    },
                );
                user_id
            };
            Ok(NewUserResult {
                user_id,
                metrics_id: "the-metrics-id".to_string(),
                inviting_user_id: None,
                signup_device_id: None,
            })
        }

        /// Not supported by the fake; panics if called.
        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }

        /// Looks up a single user by id.
        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
            self.background.simulate_random_delay().await;
            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
        }

        /// Always returns a fixed placeholder metrics id.
        async fn get_user_metrics_id(&self, _id: UserId) -> Result<String> {
            Ok("the-metrics-id".to_string())
        }

        /// Returns the users that exist among `ids`, in the order of `ids`; unknown ids are skipped.
        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
            self.background.simulate_random_delay().await;
            let users = self.users.lock();
            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
        }

        /// Not supported by the fake; panics if called.
        async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
            unimplemented!()
        }

        /// Finds a user by GitHub identity.
        ///
        /// When `github_user_id` is given, a user with that id wins: their login is updated
        /// to `github_login`. Only when no user has that id is a user with a matching login
        /// used, and that user's `github_user_id` is backfilled. Without an id, the user is
        /// matched by login only and nothing is modified.
        async fn get_user_by_github_account(
            &self,
            github_login: &str,
            github_user_id: Option<i32>,
        ) -> Result<Option<User>> {
            self.background.simulate_random_delay().await;
            let mut users = self.users.lock();
            if let Some(github_user_id) = github_user_id {
                // Search by id across *all* users before falling back to the login.
                // Checking both per user would let a login match on an earlier user
                // shadow the id match on a later one.
                if let Some(user) = users
                    .values_mut()
                    .find(|user| user.github_user_id == Some(github_user_id))
                {
                    user.github_login = github_login.into();
                    return Ok(Some(user.clone()));
                }
                if let Some(user) = users
                    .values_mut()
                    .find(|user| user.github_login == github_login)
                {
                    user.github_user_id = Some(github_user_id);
                    return Ok(Some(user.clone()));
                }
                Ok(None)
            } else {
                Ok(users
                    .values()
                    .find(|user| user.github_login == github_login)
                    .cloned())
            }
        }

        /// Not supported by the fake; panics if called.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }

        /// Sets the user's `connected_once` flag; errors if the user does not exist.
        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
            self.background.simulate_random_delay().await;
            let mut users = self.users.lock();
            let user = users
                .get_mut(&id)
                .ok_or_else(|| anyhow!("user not found"))?;
            user.connected_once = connected_once;
            Ok(())
        }

        /// Not supported by the fake; panics if called.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }

        // signups

        /// Not supported by the fake; panics if called.
        async fn create_signup(&self, _signup: Signup) -> Result<()> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn get_unsent_invites(&self, _count: usize) -> Result<Vec<Invite>> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn create_user_from_invite(
            &self,
            _invite: &Invite,
            _user: NewUserParams,
        ) -> Result<Option<NewUserResult>> {
            unimplemented!()
        }

        // invite codes

        /// Not supported by the fake; panics if called.
        async fn set_invite_count_for_user(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }

        /// The fake never has invite codes; always returns `None`.
        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            self.background.simulate_random_delay().await;
            Ok(None)
        }

        /// Not supported by the fake; panics if called.
        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn create_invite_from_code(
            &self,
            _code: &str,
            _email_address: &str,
            _device_id: Option<&str>,
        ) -> Result<Invite> {
            unimplemented!()
        }

        // projects

        /// Registers a new project hosted by `host_user_id`; errors if that user does not exist.
        async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
            self.background.simulate_random_delay().await;
            if !self.users.lock().contains_key(&host_user_id) {
                Err(anyhow!("no such user"))?;
            }

            let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
            self.projects.lock().insert(
                project_id,
                Project {
                    id: project_id,
                    host_user_id,
                    unregistered: false,
                },
            );
            Ok(project_id)
        }

        /// Marks the project as unregistered (the entry is kept); errors if it does not exist.
        async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
            self.background.simulate_random_delay().await;
            self.projects
                .lock()
                .get_mut(&project_id)
                .ok_or_else(|| anyhow!("no such project"))?
                .unregistered = true;
            Ok(())
        }

        /// Upserts the per-extension counts for one worktree of a project.
        ///
        /// Extensions absent from `extensions` keep any previously stored count.
        /// Errors if the project does not exist.
        async fn update_worktree_extensions(
            &self,
            project_id: ProjectId,
            worktree_id: u64,
            extensions: HashMap<String, u32>,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            if !self.projects.lock().contains_key(&project_id) {
                Err(anyhow!("no such project"))?;
            }

            // Lock once for the whole batch instead of once per extension.
            self.worktree_extensions.lock().extend(
                extensions
                    .into_iter()
                    .map(|(extension, count)| ((project_id, worktree_id, extension), count)),
            );

            Ok(())
        }

        /// Not supported by the fake; panics if called.
        async fn get_project_extensions(
            &self,
            _project_id: ProjectId,
        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn record_user_activity(
            &self,
            _time_period: Range<OffsetDateTime>,
            _active_projects: &[(UserId, ProjectId)],
        ) -> Result<()> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn get_active_user_count(
            &self,
            _time_period: Range<OffsetDateTime>,
            _min_duration: Duration,
            _only_collaborative: bool,
        ) -> Result<usize> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn get_top_users_activity_summary(
            &self,
            _time_period: Range<OffsetDateTime>,
            _limit: usize,
        ) -> Result<Vec<UserActivitySummary>> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn get_user_activity_timeline(
            &self,
            _time_period: Range<OffsetDateTime>,
            _user_id: UserId,
        ) -> Result<Vec<UserActivityPeriod>> {
            unimplemented!()
        }

        // contacts

        /// Lists `id`'s contacts, sorted by the other user's id.
        ///
        /// Entries where `id` is the requester become `Accepted` (carrying the entry's
        /// `should_notify`) or `Outgoing`. Entries where `id` is the responder become
        /// `Accepted` with `should_notify: false`, or `Incoming`.
        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
            self.background.simulate_random_delay().await;
            let mut contacts = Vec::new();

            for contact in self.contacts.lock().iter() {
                if contact.requester_id == id {
                    if contact.accepted {
                        contacts.push(Contact::Accepted {
                            user_id: contact.responder_id,
                            should_notify: contact.should_notify,
                        });
                    } else {
                        contacts.push(Contact::Outgoing {
                            user_id: contact.responder_id,
                        });
                    }
                } else if contact.responder_id == id {
                    if contact.accepted {
                        contacts.push(Contact::Accepted {
                            user_id: contact.requester_id,
                            should_notify: false,
                        });
                    } else {
                        contacts.push(Contact::Incoming {
                            user_id: contact.requester_id,
                            should_notify: contact.should_notify,
                        });
                    }
                }
            }

            contacts.sort_unstable_by_key(|contact| contact.user_id());
            Ok(contacts)
        }

        /// Whether the two users have an accepted contact, in either direction.
        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
            self.background.simulate_random_delay().await;
            Ok(self.contacts.lock().iter().any(|contact| {
                contact.accepted
                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
            }))
        }

        /// Records a pending request from `requester_id` to `responder_id`.
        ///
        /// If `responder_id` had already asked `requester_id`, that request is accepted
        /// instead and its notification is cleared. Errors if the contact already exists
        /// or the same request was already sent.
        async fn send_contact_request(
            &self,
            requester_id: UserId,
            responder_id: UserId,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            let mut contacts = self.contacts.lock();
            for contact in contacts.iter_mut() {
                if contact.requester_id == requester_id && contact.responder_id == responder_id {
                    if contact.accepted {
                        Err(anyhow!("contact already exists"))?;
                    } else {
                        Err(anyhow!("contact already requested"))?;
                    }
                }
                if contact.responder_id == requester_id && contact.requester_id == responder_id {
                    if contact.accepted {
                        Err(anyhow!("contact already exists"))?;
                    } else {
                        contact.accepted = true;
                        contact.should_notify = false;
                        return Ok(());
                    }
                }
            }
            contacts.push(FakeContact {
                requester_id,
                responder_id,
                accepted: false,
                should_notify: true,
            });
            Ok(())
        }

        /// Removes the contact (accepted or pending) between the two users, whichever of
        /// them originally sent the request. This matches `has_contact` and
        /// `send_contact_request`, which also treat the relationship as undirected.
        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
            self.background.simulate_random_delay().await;
            self.contacts.lock().retain(|contact| {
                let forward =
                    contact.requester_id == requester_id && contact.responder_id == responder_id;
                let reverse =
                    contact.requester_id == responder_id && contact.responder_id == requester_id;
                !(forward || reverse)
            });
            Ok(())
        }

        /// Clears the pending notification that `user_id` has about `contact_user_id`.
        ///
        /// This is either an incoming, unaccepted request from `contact_user_id`, or an
        /// accepted request that `user_id` sent. Errors if there is no such notification.
        async fn dismiss_contact_notification(
            &self,
            user_id: UserId,
            contact_user_id: UserId,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            let mut contacts = self.contacts.lock();
            for contact in contacts.iter_mut() {
                if contact.requester_id == contact_user_id
                    && contact.responder_id == user_id
                    && !contact.accepted
                {
                    contact.should_notify = false;
                    return Ok(());
                }
                if contact.requester_id == user_id
                    && contact.responder_id == contact_user_id
                    && contact.accepted
                {
                    contact.should_notify = false;
                    return Ok(());
                }
            }
            Err(anyhow!("no such notification"))?
        }

        /// Accepts (notifying the requester) or declines (deleting) a pending request
        /// from `requester_id` to `responder_id`.
        ///
        /// Errors if there is no such request or it was already accepted.
        async fn respond_to_contact_request(
            &self,
            responder_id: UserId,
            requester_id: UserId,
            accept: bool,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            let mut contacts = self.contacts.lock();
            // Locate the request first so it can be removed without mutating the
            // vector while iterating over it.
            let ix = contacts
                .iter()
                .position(|contact| {
                    contact.requester_id == requester_id && contact.responder_id == responder_id
                })
                .ok_or_else(|| anyhow!("no such contact request"))?;
            if contacts[ix].accepted {
                Err(anyhow!("contact already confirmed"))?;
            }
            if accept {
                let contact = &mut contacts[ix];
                contact.accepted = true;
                contact.should_notify = true;
            } else {
                contacts.remove(ix);
            }
            Ok(())
        }

        /// Not supported by the fake; panics if called.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }

        /// Not supported by the fake; panics if called.
        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }

        /// Creates an org; errors if an org with the same slug already exists.
        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
            self.background.simulate_random_delay().await;
            let mut orgs = self.orgs.lock();
            if orgs.values().any(|org| org.slug == slug) {
                Err(anyhow!("org already exists"))?
            } else {
                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
                orgs.insert(
                    org_id,
                    Org {
                        id: org_id,
                        name: name.to_string(),
                        slug: slug.to_string(),
                    },
                );
                Ok(org_id)
            }
        }

        /// Adds `user_id` to the org. An existing membership is left unchanged, including
        /// its admin flag. Errors if the org or user does not exist.
        async fn add_org_member(
            &self,
            org_id: OrgId,
            user_id: UserId,
            is_admin: bool,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            if !self.orgs.lock().contains_key(&org_id) {
                Err(anyhow!("org does not exist"))?;
            }
            if !self.users.lock().contains_key(&user_id) {
                Err(anyhow!("user does not exist"))?;
            }

            self.org_memberships
                .lock()
                .entry((org_id, user_id))
                .or_insert(is_admin);
            Ok(())
        }

        /// Creates a channel owned by the org; errors if the org does not exist.
        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
            self.background.simulate_random_delay().await;
            if !self.orgs.lock().contains_key(&org_id) {
                Err(anyhow!("org does not exist"))?;
            }

            let mut channels = self.channels.lock();
            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
            channels.insert(
                channel_id,
                Channel {
                    id: channel_id,
                    name: name.to_string(),
                    owner_id: org_id.0,
                    owner_is_user: false,
                },
            );
            Ok(channel_id)
        }

        /// Returns every channel owned by the org, in channel-id order.
        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
            self.background.simulate_random_delay().await;
            Ok(self
                .channels
                .lock()
                .values()
                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
                .cloned()
                .collect())
        }

        /// Returns every channel `user_id` is a member of, in channel-id order.
        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
            self.background.simulate_random_delay().await;
            let channels = self.channels.lock();
            let memberships = self.channel_memberships.lock();
            Ok(channels
                .values()
                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
                .cloned()
                .collect())
        }

        /// Whether `user_id` is a member of `channel_id`.
        async fn can_user_access_channel(
            &self,
            user_id: UserId,
            channel_id: ChannelId,
        ) -> Result<bool> {
            self.background.simulate_random_delay().await;
            Ok(self
                .channel_memberships
                .lock()
                .contains_key(&(channel_id, user_id)))
        }

        /// Adds `user_id` to the channel. An existing membership is left unchanged,
        /// including its admin flag. Errors if the channel or user does not exist.
        async fn add_channel_member(
            &self,
            channel_id: ChannelId,
            user_id: UserId,
            is_admin: bool,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            if !self.channels.lock().contains_key(&channel_id) {
                Err(anyhow!("channel does not exist"))?;
            }
            if !self.users.lock().contains_key(&user_id) {
                Err(anyhow!("user does not exist"))?;
            }

            self.channel_memberships
                .lock()
                .entry((channel_id, user_id))
                .or_insert(is_admin);
            Ok(())
        }

        /// Stores a message and returns its id.
        ///
        /// This is idempotent on `nonce`: if any message (in any channel) already has
        /// this nonce, its id is returned and nothing is inserted. Errors if the channel
        /// or sender does not exist.
        async fn create_channel_message(
            &self,
            channel_id: ChannelId,
            sender_id: UserId,
            body: &str,
            timestamp: OffsetDateTime,
            nonce: u128,
        ) -> Result<MessageId> {
            self.background.simulate_random_delay().await;
            if !self.channels.lock().contains_key(&channel_id) {
                Err(anyhow!("channel does not exist"))?;
            }
            if !self.users.lock().contains_key(&sender_id) {
                Err(anyhow!("user does not exist"))?;
            }

            let mut messages = self.channel_messages.lock();
            if let Some(message) = messages
                .values()
                .find(|message| message.nonce.as_u128() == nonce)
            {
                Ok(message.id)
            } else {
                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
                messages.insert(
                    message_id,
                    ChannelMessage {
                        id: message_id,
                        channel_id,
                        sender_id,
                        body: body.to_string(),
                        sent_at: timestamp,
                        nonce: Uuid::from_u128(nonce),
                    },
                );
                Ok(message_id)
            }
        }

        /// Returns the `count` most recent messages of `channel_id` with ids below
        /// `before_id` (no bound when `None`), in ascending id order.
        async fn get_channel_messages(
            &self,
            channel_id: ChannelId,
            count: usize,
            before_id: Option<MessageId>,
        ) -> Result<Vec<ChannelMessage>> {
            self.background.simulate_random_delay().await;
            let mut messages = self
                .channel_messages
                .lock()
                .values()
                .rev()
                .filter(|message| {
                    message.channel_id == channel_id
                        && message.id < before_id.unwrap_or(MessageId::MAX)
                })
                .take(count)
                .cloned()
                .collect::<Vec<_>>();
            messages.sort_unstable_by_key(|message| message.id);
            Ok(messages)
        }

        /// Nothing to tear down for an in-memory database.
        async fn teardown(&self, _: &str) {}

        #[cfg(test)]
        fn as_fake(&self) -> Option<&FakeDb> {
            Some(self)
        }
    }

    /// A database handle for one test, plus the URL used to tear it down.
    ///
    /// `db` is taken out of the `Option` in `Drop` to run the teardown.
    pub struct TestDb {
        pub db: Option<Arc<dyn Db>>,
        pub url: String,
    }

    impl TestDb {
        /// Creates a fresh, randomly named, migrated database on the Postgres server at
        /// `localhost:5433`.
        ///
        /// # Panics
        /// Panics if the database cannot be created, connected to, or migrated.
        // The global LOCK is deliberately held across the awaits below so that database
        // creation and migration never run concurrently between tests.
        #[allow(clippy::await_holding_lock)]
        pub async fn postgres() -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            eprintln!("creating database...");
            let start = std::time::Instant::now();
            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost:5433/{}", name);
            Postgres::create_database(&url)
                .await
                .expect("failed to create test db");
            let db = PostgresDb::new(&url, 5).await.unwrap();
            db.migrate(Path::new(DEFAULT_MIGRATIONS_PATH.unwrap()), false)
                .await
                .unwrap();

            eprintln!("created database: {:?}", start.elapsed());
            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Creates an in-memory [`FakeDb`] whose simulated delays run on `background`.
        pub fn fake(background: Arc<Background>) -> Self {
            Self {
                db: Some(Arc::new(FakeDb::new(background))),
                url: Default::default(),
            }
        }

        /// Returns the database handle.
        ///
        /// # Panics
        /// Panics if called after the handle has been taken (only `Drop` takes it).
        pub fn db(&self) -> &Arc<dyn Db> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for TestDb {
        /// Tears the database down synchronously, because `Drop` cannot await.
        fn drop(&mut self) {
            if let Some(db) = self.db.take() {
                futures::executor::block_on(db.teardown(&self.url));
            }
        }
    }
}