db.rs

   1use crate::{Error, Result};
   2use anyhow::{anyhow, Context};
   3use async_trait::async_trait;
   4use axum::http::StatusCode;
   5use collections::HashMap;
   6use futures::StreamExt;
   7use serde::{Deserialize, Serialize};
   8use sqlx::{
   9    migrate::{Migrate as _, Migration, MigrationSource},
  10    types::Uuid,
  11    FromRow, QueryBuilder,
  12};
  13use std::{cmp, ops::Range, path::Path, time::Duration};
  14use time::{OffsetDateTime, PrimitiveDateTime};
  15
  16#[async_trait]
  17pub trait Db: Send + Sync {
  18    async fn create_user(
  19        &self,
  20        email_address: &str,
  21        admin: bool,
  22        params: NewUserParams,
  23    ) -> Result<NewUserResult>;
  24    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
  25    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
  26    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
  27    async fn get_user_metrics_id(&self, id: UserId) -> Result<String>;
  28    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
  29    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
  30    async fn get_user_by_github_account(
  31        &self,
  32        github_login: &str,
  33        github_user_id: Option<i32>,
  34    ) -> Result<Option<User>>;
  35    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
  36    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
  37    async fn destroy_user(&self, id: UserId) -> Result<()>;
  38
  39    async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()>;
  40    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
  41    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
  42    async fn create_invite_from_code(
  43        &self,
  44        code: &str,
  45        email_address: &str,
  46        device_id: Option<&str>,
  47    ) -> Result<Invite>;
  48
  49    async fn create_signup(&self, signup: Signup) -> Result<()>;
  50    async fn get_waitlist_summary(&self) -> Result<WaitlistSummary>;
  51    async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>>;
  52    async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()>;
  53    async fn create_user_from_invite(
  54        &self,
  55        invite: &Invite,
  56        user: NewUserParams,
  57    ) -> Result<Option<NewUserResult>>;
  58
  59    /// Registers a new project for the given user.
  60    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
  61
  62    /// Unregisters a project for the given project id.
  63    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
  64
  65    /// Update file counts by extension for the given project and worktree.
  66    async fn update_worktree_extensions(
  67        &self,
  68        project_id: ProjectId,
  69        worktree_id: u64,
  70        extensions: HashMap<String, u32>,
  71    ) -> Result<()>;
  72
  73    /// Get the file counts on the given project keyed by their worktree and extension.
  74    async fn get_project_extensions(
  75        &self,
  76        project_id: ProjectId,
  77    ) -> Result<HashMap<u64, HashMap<String, usize>>>;
  78
  79    /// Record which users have been active in which projects during
  80    /// a given period of time.
  81    async fn record_user_activity(
  82        &self,
  83        time_period: Range<OffsetDateTime>,
  84        active_projects: &[(UserId, ProjectId)],
  85    ) -> Result<()>;
  86
  87    /// Get the number of users who have been active in the given
  88    /// time period for at least the given time duration.
  89    async fn get_active_user_count(
  90        &self,
  91        time_period: Range<OffsetDateTime>,
  92        min_duration: Duration,
  93        only_collaborative: bool,
  94    ) -> Result<usize>;
  95
  96    /// Get the users that have been most active during the given time period,
  97    /// along with the amount of time they have been active in each project.
  98    async fn get_top_users_activity_summary(
  99        &self,
 100        time_period: Range<OffsetDateTime>,
 101        max_user_count: usize,
 102    ) -> Result<Vec<UserActivitySummary>>;
 103
 104    /// Get the project activity for the given user and time period.
 105    async fn get_user_activity_timeline(
 106        &self,
 107        time_period: Range<OffsetDateTime>,
 108        user_id: UserId,
 109    ) -> Result<Vec<UserActivityPeriod>>;
 110
 111    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
 112    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
 113    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
 114    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
 115    async fn dismiss_contact_notification(
 116        &self,
 117        responder_id: UserId,
 118        requester_id: UserId,
 119    ) -> Result<()>;
 120    async fn respond_to_contact_request(
 121        &self,
 122        responder_id: UserId,
 123        requester_id: UserId,
 124        accept: bool,
 125    ) -> Result<()>;
 126
 127    async fn create_access_token_hash(
 128        &self,
 129        user_id: UserId,
 130        access_token_hash: &str,
 131        max_access_token_count: usize,
 132    ) -> Result<()>;
 133    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
 134
 135    #[cfg(any(test, feature = "seed-support"))]
 136    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
 137    #[cfg(any(test, feature = "seed-support"))]
 138    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
 139    #[cfg(any(test, feature = "seed-support"))]
 140    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
 141    #[cfg(any(test, feature = "seed-support"))]
 142    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
 143    #[cfg(any(test, feature = "seed-support"))]
 144
 145    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
 146    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
 147    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
 148        -> Result<bool>;
 149
 150    #[cfg(any(test, feature = "seed-support"))]
 151    async fn add_channel_member(
 152        &self,
 153        channel_id: ChannelId,
 154        user_id: UserId,
 155        is_admin: bool,
 156    ) -> Result<()>;
 157    async fn create_channel_message(
 158        &self,
 159        channel_id: ChannelId,
 160        sender_id: UserId,
 161        body: &str,
 162        timestamp: OffsetDateTime,
 163        nonce: u128,
 164    ) -> Result<MessageId>;
 165    async fn get_channel_messages(
 166        &self,
 167        channel_id: ChannelId,
 168        count: usize,
 169        before_id: Option<MessageId>,
 170    ) -> Result<Vec<ChannelMessage>>;
 171
 172    #[cfg(test)]
 173    async fn teardown(&self, url: &str);
 174
 175    #[cfg(test)]
 176    fn as_fake(&self) -> Option<&FakeDb>;
 177}
 178
 179#[cfg(any(test, debug_assertions))]
 180pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> =
 181    Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
 182
 183pub const TEST_MIGRATIONS_PATH: Option<&'static str> =
 184    Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite"));
 185
 186#[cfg(not(any(test, debug_assertions)))]
 187pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> = None;
 188
 189pub struct RealDb {
 190    pool: sqlx::SqlitePool,
 191}
 192
 193macro_rules! test_support {
 194    ($self:ident, { $($token:tt)* }) => {{
 195        let body = async {
 196            $($token)*
 197        };
 198
 199        if cfg!(test) {
 200            tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build().unwrap().block_on(body)
 201        } else {
 202            body.await
 203        }
 204    }};
 205}
 206
 207impl RealDb {
 208    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 209        eprintln!("{url}");
 210        let pool = sqlx::sqlite::SqlitePoolOptions::new()
 211            .max_connections(1)
 212            .connect(url)
 213            .await?;
 214        Ok(Self { pool })
 215    }
 216
 217    pub async fn migrate(
 218        &self,
 219        migrations_path: &Path,
 220        ignore_checksum_mismatch: bool,
 221    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 222        let migrations = MigrationSource::resolve(migrations_path)
 223            .await
 224            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 225
 226        let mut conn = self.pool.acquire().await?;
 227
 228        conn.ensure_migrations_table().await?;
 229        let applied_migrations: HashMap<_, _> = conn
 230            .list_applied_migrations()
 231            .await?
 232            .into_iter()
 233            .map(|m| (m.version, m))
 234            .collect();
 235
 236        let mut new_migrations = Vec::new();
 237        for migration in migrations {
 238            match applied_migrations.get(&migration.version) {
 239                Some(applied_migration) => {
 240                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 241                    {
 242                        Err(anyhow!(
 243                            "checksum mismatch for applied migration {}",
 244                            migration.description
 245                        ))?;
 246                    }
 247                }
 248                None => {
 249                    let elapsed = conn.apply(&migration).await?;
 250                    new_migrations.push((migration, elapsed));
 251                }
 252            }
 253        }
 254
 255        Ok(new_migrations)
 256    }
 257
 258    pub fn fuzzy_like_string(string: &str) -> String {
 259        let mut result = String::with_capacity(string.len() * 2 + 1);
 260        for c in string.chars() {
 261            if c.is_alphanumeric() {
 262                result.push('%');
 263                result.push(c);
 264            }
 265        }
 266        result.push('%');
 267        result
 268    }
 269}
 270
 271#[async_trait]
 272impl Db for RealDb {
 273    // users
 274
 275    async fn create_user(
 276        &self,
 277        email_address: &str,
 278        admin: bool,
 279        params: NewUserParams,
 280    ) -> Result<NewUserResult> {
 281        test_support!(self, {
 282            let query = "
 283                INSERT INTO users (email_address, github_login, github_user_id, admin)
 284                VALUES ($1, $2, $3, $4)
 285                -- ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 286                RETURNING id, 'the-metrics-id'
 287            ";
 288
 289            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
 290                .bind(email_address)
 291                .bind(params.github_login)
 292                .bind(params.github_user_id)
 293                .bind(admin)
 294                .fetch_one(&self.pool)
 295                .await?;
 296            Ok(NewUserResult {
 297                user_id,
 298                metrics_id,
 299                signup_device_id: None,
 300                inviting_user_id: None,
 301            })
 302        })
 303    }
 304
 305    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 306        test_support!(self, {
 307            let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 308            Ok(sqlx::query_as(query)
 309                .bind(limit as i32)
 310                .bind((page * limit) as i32)
 311                .fetch_all(&self.pool)
 312                .await?)
 313        })
 314    }
 315
 316    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 317        test_support!(self, {
 318            let like_string = Self::fuzzy_like_string(name_query);
 319            let query = "
 320                SELECT users.*
 321                FROM users
 322                WHERE github_login ILIKE $1
 323                ORDER BY github_login <-> $2
 324                LIMIT $3
 325            ";
 326            Ok(sqlx::query_as(query)
 327                .bind(like_string)
 328                .bind(name_query)
 329                .bind(limit as i32)
 330                .fetch_all(&self.pool)
 331                .await?)
 332        })
 333    }
 334
 335    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 336        test_support!(self, {
 337            let query = "
 338                SELECT users.*
 339                FROM users
 340                WHERE id = $1
 341                LIMIT 1
 342            ";
 343            Ok(sqlx::query_as(query)
 344                .bind(&id)
 345                .fetch_optional(&self.pool)
 346                .await?)
 347        })
 348    }
 349
 350    async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 351        test_support!(self, {
 352            let query = "
 353                SELECT metrics_id::text
 354                FROM users
 355                WHERE id = $1
 356            ";
 357            Ok(sqlx::query_scalar(query)
 358                .bind(id)
 359                .fetch_one(&self.pool)
 360                .await?)
 361        })
 362    }
 363
 364    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 365        test_support!(self, {
 366            let query = "
 367                SELECT users.*
 368                FROM users
 369                WHERE users.id IN (SELECT value from json_each($1))
 370            ";
 371            Ok(sqlx::query_as(query)
 372                .bind(&serde_json::json!(ids))
 373                .fetch_all(&self.pool)
 374                .await?)
 375        })
 376    }
 377
 378    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
 379        test_support!(self, {
 380            let query = format!(
 381                "
 382                SELECT users.*
 383                FROM users
 384                WHERE invite_count = 0
 385                AND inviter_id IS{} NULL
 386                ",
 387                if invited_by_another_user { " NOT" } else { "" }
 388            );
 389
 390            Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 391        })
 392    }
 393
 394    async fn get_user_by_github_account(
 395        &self,
 396        github_login: &str,
 397        github_user_id: Option<i32>,
 398    ) -> Result<Option<User>> {
 399        test_support!(self, {
 400            if let Some(github_user_id) = github_user_id {
 401                let mut user = sqlx::query_as::<_, User>(
 402                    "
 403                    UPDATE users
 404                    SET github_login = $1
 405                    WHERE github_user_id = $2
 406                    RETURNING *
 407                    ",
 408                )
 409                .bind(github_login)
 410                .bind(github_user_id)
 411                .fetch_optional(&self.pool)
 412                .await?;
 413
 414                if user.is_none() {
 415                    user = sqlx::query_as::<_, User>(
 416                        "
 417                        UPDATE users
 418                        SET github_user_id = $1
 419                        WHERE github_login = $2
 420                        RETURNING *
 421                        ",
 422                    )
 423                    .bind(github_user_id)
 424                    .bind(github_login)
 425                    .fetch_optional(&self.pool)
 426                    .await?;
 427                }
 428
 429                Ok(user)
 430            } else {
 431                let user = sqlx::query_as(
 432                    "
 433                    SELECT * FROM users
 434                    WHERE github_login = $1
 435                    LIMIT 1
 436                    ",
 437                )
 438                .bind(github_login)
 439                .fetch_optional(&self.pool)
 440                .await?;
 441                Ok(user)
 442            }
 443        })
 444    }
 445
 446    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 447        test_support!(self, {
 448            let query = "UPDATE users SET admin = $1 WHERE id = $2";
 449            Ok(sqlx::query(query)
 450                .bind(is_admin)
 451                .bind(id.0)
 452                .execute(&self.pool)
 453                .await
 454                .map(drop)?)
 455        })
 456    }
 457
 458    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 459        test_support!(self, {
 460            let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 461            Ok(sqlx::query(query)
 462                .bind(connected_once)
 463                .bind(id.0)
 464                .execute(&self.pool)
 465                .await
 466                .map(drop)?)
 467        })
 468    }
 469
 470    async fn destroy_user(&self, id: UserId) -> Result<()> {
 471        test_support!(self, {
 472            let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 473            sqlx::query(query)
 474                .bind(id.0)
 475                .execute(&self.pool)
 476                .await
 477                .map(drop)?;
 478            let query = "DELETE FROM users WHERE id = $1;";
 479            Ok(sqlx::query(query)
 480                .bind(id.0)
 481                .execute(&self.pool)
 482                .await
 483                .map(drop)?)
 484        })
 485    }
 486
 487    // signups
 488
 489    async fn create_signup(&self, signup: Signup) -> Result<()> {
 490        test_support!(self, {
 491            sqlx::query(
 492                "
 493                INSERT INTO signups
 494                (
 495                    email_address,
 496                    email_confirmation_code,
 497                    email_confirmation_sent,
 498                    platform_linux,
 499                    platform_mac,
 500                    platform_windows,
 501                    platform_unknown,
 502                    editor_features,
 503                    programming_languages,
 504                    device_id
 505                )
 506                VALUES
 507                    ($1, $2, FALSE, $3, $4, $5, FALSE, $6)
 508                RETURNING id
 509                ",
 510            )
 511            .bind(&signup.email_address)
 512            .bind(&random_email_confirmation_code())
 513            .bind(&signup.platform_linux)
 514            .bind(&signup.platform_mac)
 515            .bind(&signup.platform_windows)
 516            // .bind(&signup.editor_features)
 517            // .bind(&signup.programming_languages)
 518            .bind(&signup.device_id)
 519            .execute(&self.pool)
 520            .await?;
 521            Ok(())
 522        })
 523    }
 524
 525    async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
 526        test_support!(self, {
 527            Ok(sqlx::query_as(
 528                "
 529                SELECT
 530                    COUNT(*) as count,
 531                    COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
 532                    COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
 533                    COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
 534                    COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
 535                FROM (
 536                    SELECT *
 537                    FROM signups
 538                    WHERE
 539                        NOT email_confirmation_sent
 540                ) AS unsent
 541                ",
 542            )
 543            .fetch_one(&self.pool)
 544            .await?)
 545        })
 546    }
 547
 548    async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
 549        test_support!(self, {
 550            Ok(sqlx::query_as(
 551                "
 552                SELECT
 553                    email_address, email_confirmation_code
 554                FROM signups
 555                WHERE
 556                    NOT email_confirmation_sent AND
 557                    (platform_mac OR platform_unknown)
 558                LIMIT $1
 559                ",
 560            )
 561            .bind(count as i32)
 562            .fetch_all(&self.pool)
 563            .await?)
 564        })
 565    }
 566
 567    async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
 568        test_support!(self, {
 569            // sqlx::query(
 570            //     "
 571            //     UPDATE signups
 572            //     SET email_confirmation_sent = TRUE
 573            //     WHERE email_address = ANY ($1)
 574            //     ",
 575            // )
 576            // .bind(
 577            //     &invites
 578            //         .iter()
 579            //         .map(|s| s.email_address.as_str())
 580            //         .collect::<Vec<_>>(),
 581            // )
 582            // .execute(&self.pool)
 583            // .await?;
 584            Ok(())
 585        })
 586    }
 587
 588    async fn create_user_from_invite(
 589        &self,
 590        invite: &Invite,
 591        user: NewUserParams,
 592    ) -> Result<Option<NewUserResult>> {
 593        test_support!(self, {
 594            let mut tx = self.pool.begin().await?;
 595
 596            let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
 597                i32,
 598                Option<UserId>,
 599                Option<UserId>,
 600                Option<String>,
 601            ) = sqlx::query_as(
 602                "
 603                SELECT id, user_id, inviting_user_id, device_id
 604                FROM signups
 605                WHERE
 606                    email_address = $1 AND
 607                    email_confirmation_code = $2
 608                ",
 609            )
 610            .bind(&invite.email_address)
 611            .bind(&invite.email_confirmation_code)
 612            .fetch_optional(&mut tx)
 613            .await?
 614            .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
 615
 616            if existing_user_id.is_some() {
 617                return Ok(None);
 618            }
 619
 620            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
 621                "
 622                INSERT INTO users
 623                (email_address, github_login, github_user_id, admin, invite_count, invite_code)
 624                VALUES
 625                ($1, $2, $3, FALSE, $4, $5)
 626                ON CONFLICT (github_login) DO UPDATE SET
 627                    email_address = excluded.email_address,
 628                    github_user_id = excluded.github_user_id,
 629                    admin = excluded.admin
 630                RETURNING id, metrics_id::text
 631                ",
 632            )
 633            .bind(&invite.email_address)
 634            .bind(&user.github_login)
 635            .bind(&user.github_user_id)
 636            .bind(&user.invite_count)
 637            .bind(random_invite_code())
 638            .fetch_one(&mut tx)
 639            .await?;
 640
 641            sqlx::query(
 642                "
 643                UPDATE signups
 644                SET user_id = $1
 645                WHERE id = $2
 646                ",
 647            )
 648            .bind(&user_id)
 649            .bind(&signup_id)
 650            .execute(&mut tx)
 651            .await?;
 652
 653            if let Some(inviting_user_id) = inviting_user_id {
 654                let id: Option<UserId> = sqlx::query_scalar(
 655                    "
 656                    UPDATE users
 657                    SET invite_count = invite_count - 1
 658                    WHERE id = $1 AND invite_count > 0
 659                    RETURNING id
 660                    ",
 661                )
 662                .bind(&inviting_user_id)
 663                .fetch_optional(&mut tx)
 664                .await?;
 665
 666                if id.is_none() {
 667                    Err(Error::Http(
 668                        StatusCode::UNAUTHORIZED,
 669                        "no invites remaining".to_string(),
 670                    ))?;
 671                }
 672
 673                sqlx::query(
 674                    "
 675                    INSERT INTO contacts
 676                        (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 677                    VALUES
 678                        ($1, $2, TRUE, TRUE, TRUE)
 679                    ON CONFLICT DO NOTHING
 680                    ",
 681                )
 682                .bind(inviting_user_id)
 683                .bind(user_id)
 684                .execute(&mut tx)
 685                .await?;
 686            }
 687
 688            tx.commit().await?;
 689            Ok(Some(NewUserResult {
 690                user_id,
 691                metrics_id,
 692                inviting_user_id,
 693                signup_device_id,
 694            }))
 695        })
 696    }
 697
 698    // invite codes
 699
 700    async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
 701        test_support!(self, {
 702            let mut tx = self.pool.begin().await?;
 703            if count > 0 {
 704                sqlx::query(
 705                    "
 706                    UPDATE users
 707                    SET invite_code = $1
 708                    WHERE id = $2 AND invite_code IS NULL
 709                ",
 710                )
 711                .bind(random_invite_code())
 712                .bind(id)
 713                .execute(&mut tx)
 714                .await?;
 715            }
 716
 717            sqlx::query(
 718                "
 719                UPDATE users
 720                SET invite_count = $1
 721                WHERE id = $2
 722                ",
 723            )
 724            .bind(count as i32)
 725            .bind(id)
 726            .execute(&mut tx)
 727            .await?;
 728            tx.commit().await?;
 729            Ok(())
 730        })
 731    }
 732
 733    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 734        test_support!(self, {
 735            let result: Option<(String, i32)> = sqlx::query_as(
 736                "
 737                    SELECT invite_code, invite_count
 738                    FROM users
 739                    WHERE id = $1 AND invite_code IS NOT NULL 
 740                ",
 741            )
 742            .bind(id)
 743            .fetch_optional(&self.pool)
 744            .await?;
 745            if let Some((code, count)) = result {
 746                Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 747            } else {
 748                Ok(None)
 749            }
 750        })
 751    }
 752
 753    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 754        test_support!(self, {
 755            sqlx::query_as(
 756                "
 757                    SELECT *
 758                    FROM users
 759                    WHERE invite_code = $1
 760                ",
 761            )
 762            .bind(code)
 763            .fetch_optional(&self.pool)
 764            .await?
 765            .ok_or_else(|| {
 766                Error::Http(
 767                    StatusCode::NOT_FOUND,
 768                    "that invite code does not exist".to_string(),
 769                )
 770            })
 771        })
 772    }
 773
 774    async fn create_invite_from_code(
 775        &self,
 776        code: &str,
 777        email_address: &str,
 778        device_id: Option<&str>,
 779    ) -> Result<Invite> {
 780        test_support!(self, {
 781            let mut tx = self.pool.begin().await?;
 782
 783            let existing_user: Option<UserId> = sqlx::query_scalar(
 784                "
 785                SELECT id
 786                FROM users
 787                WHERE email_address = $1
 788                ",
 789            )
 790            .bind(email_address)
 791            .fetch_optional(&mut tx)
 792            .await?;
 793            if existing_user.is_some() {
 794                Err(anyhow!("email address is already in use"))?;
 795            }
 796
 797            let row: Option<(UserId, i32)> = sqlx::query_as(
 798                "
 799                SELECT id, invite_count
 800                FROM users
 801                WHERE invite_code = $1
 802                ",
 803            )
 804            .bind(code)
 805            .fetch_optional(&mut tx)
 806            .await?;
 807
 808            let (inviter_id, invite_count) = match row {
 809                Some(row) => row,
 810                None => Err(Error::Http(
 811                    StatusCode::NOT_FOUND,
 812                    "invite code not found".to_string(),
 813                ))?,
 814            };
 815
 816            if invite_count == 0 {
 817                Err(Error::Http(
 818                    StatusCode::UNAUTHORIZED,
 819                    "no invites remaining".to_string(),
 820                ))?;
 821            }
 822
 823            let email_confirmation_code: String = sqlx::query_scalar(
 824                "
 825                INSERT INTO signups
 826                (
 827                    email_address,
 828                    email_confirmation_code,
 829                    email_confirmation_sent,
 830                    inviting_user_id,
 831                    platform_linux,
 832                    platform_mac,
 833                    platform_windows,
 834                    platform_unknown,
 835                    device_id
 836                )
 837                VALUES
 838                    ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
 839                ON CONFLICT (email_address)
 840                DO UPDATE SET
 841                    inviting_user_id = excluded.inviting_user_id
 842                RETURNING email_confirmation_code
 843                ",
 844            )
 845            .bind(&email_address)
 846            .bind(&random_email_confirmation_code())
 847            .bind(&inviter_id)
 848            .bind(&device_id)
 849            .fetch_one(&mut tx)
 850            .await?;
 851
 852            tx.commit().await?;
 853
 854            Ok(Invite {
 855                email_address: email_address.into(),
 856                email_confirmation_code,
 857            })
 858        })
 859    }
 860
 861    // projects
 862
 863    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
 864        test_support!(self, {
 865            Ok(sqlx::query_scalar(
 866                "
 867                INSERT INTO projects(host_user_id)
 868                VALUES ($1)
 869                RETURNING id
 870                ",
 871            )
 872            .bind(host_user_id)
 873            .fetch_one(&self.pool)
 874            .await
 875            .map(ProjectId)?)
 876        })
 877    }
 878
 879    async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
 880        test_support!(self, {
 881            sqlx::query(
 882                "
 883                UPDATE projects
 884                SET unregistered = TRUE
 885                WHERE id = $1
 886                ",
 887            )
 888            .bind(project_id)
 889            .execute(&self.pool)
 890            .await?;
 891            Ok(())
 892        })
 893    }
 894
 895    async fn update_worktree_extensions(
 896        &self,
 897        project_id: ProjectId,
 898        worktree_id: u64,
 899        extensions: HashMap<String, u32>,
 900    ) -> Result<()> {
 901        test_support!(self, {
 902            if extensions.is_empty() {
 903                return Ok(());
 904            }
 905
 906            let mut query = QueryBuilder::new(
 907                "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
 908            );
 909            query.push_values(extensions, |mut query, (extension, count)| {
 910                query
 911                    .push_bind(project_id)
 912                    .push_bind(worktree_id as i32)
 913                    .push_bind(extension)
 914                    .push_bind(count as i32);
 915            });
 916            query.push(
 917                "
 918                ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
 919                count = excluded.count
 920                ",
 921            );
 922            query.build().execute(&self.pool).await?;
 923
 924            Ok(())
 925        })
 926    }
 927
 928    async fn get_project_extensions(
 929        &self,
 930        project_id: ProjectId,
 931    ) -> Result<HashMap<u64, HashMap<String, usize>>> {
 932        test_support!(self, {
 933            #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
 934            struct WorktreeExtension {
 935                worktree_id: i32,
 936                extension: String,
 937                count: i32,
 938            }
 939
 940            let query = "
 941                SELECT worktree_id, extension, count
 942                FROM worktree_extensions
 943                WHERE project_id = $1
 944            ";
 945            let counts = sqlx::query_as::<_, WorktreeExtension>(query)
 946                .bind(&project_id)
 947                .fetch_all(&self.pool)
 948                .await?;
 949
 950            let mut extension_counts = HashMap::default();
 951            for count in counts {
 952                extension_counts
 953                    .entry(count.worktree_id as u64)
 954                    .or_insert_with(HashMap::default)
 955                    .insert(count.extension, count.count as usize);
 956            }
 957            Ok(extension_counts)
 958        })
 959    }
 960
 961    async fn record_user_activity(
 962        &self,
 963        time_period: Range<OffsetDateTime>,
 964        projects: &[(UserId, ProjectId)],
 965    ) -> Result<()> {
 966        test_support!(self, {
 967            let query = "
 968                INSERT INTO project_activity_periods
 969                (ended_at, duration_millis, user_id, project_id)
 970                VALUES
 971                ($1, $2, $3, $4);
 972            ";
 973
 974            let mut tx = self.pool.begin().await?;
 975            let duration_millis =
 976                ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
 977            for (user_id, project_id) in projects {
 978                sqlx::query(query)
 979                    .bind(time_period.end)
 980                    .bind(duration_millis)
 981                    .bind(user_id)
 982                    .bind(project_id)
 983                    .execute(&mut tx)
 984                    .await?;
 985            }
 986            tx.commit().await?;
 987
 988            Ok(())
 989        })
 990    }
 991
 992    async fn get_active_user_count(
 993        &self,
 994        time_period: Range<OffsetDateTime>,
 995        min_duration: Duration,
 996        only_collaborative: bool,
 997    ) -> Result<usize> {
 998        test_support!(self, {
 999            let mut with_clause = String::new();
1000            with_clause.push_str("WITH\n");
1001            with_clause.push_str(
1002                "
1003                project_durations AS (
1004                    SELECT user_id, project_id, SUM(duration_millis) AS project_duration
1005                    FROM project_activity_periods
1006                    WHERE $1 < ended_at AND ended_at <= $2
1007                    GROUP BY user_id, project_id
1008                ),
1009                ",
1010            );
1011            with_clause.push_str(
1012                "
1013                project_collaborators as (
1014                    SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
1015                    FROM project_durations
1016                    GROUP BY project_id
1017                ),
1018                ",
1019            );
1020
1021            if only_collaborative {
1022                with_clause.push_str(
1023                    "
1024                    user_durations AS (
1025                        SELECT user_id, SUM(project_duration) as total_duration
1026                        FROM project_durations, project_collaborators
1027                        WHERE
1028                            project_durations.project_id = project_collaborators.project_id AND
1029                            max_collaborators > 1
1030                        GROUP BY user_id
1031                        ORDER BY total_duration DESC
1032                        LIMIT $3
1033                    )
1034                    ",
1035                );
1036            } else {
1037                with_clause.push_str(
1038                    "
1039                    user_durations AS (
1040                        SELECT user_id, SUM(project_duration) as total_duration
1041                        FROM project_durations
1042                        GROUP BY user_id
1043                        ORDER BY total_duration DESC
1044                        LIMIT $3
1045                    )
1046                    ",
1047                );
1048            }
1049
1050            let query = format!(
1051                "
1052                {with_clause}
1053                SELECT count(user_durations.user_id)
1054                FROM user_durations
1055                WHERE user_durations.total_duration >= $3
1056                "
1057            );
1058
1059            let count: i64 = sqlx::query_scalar(&query)
1060                .bind(time_period.start)
1061                .bind(time_period.end)
1062                .bind(min_duration.as_millis() as i64)
1063                .fetch_one(&self.pool)
1064                .await?;
1065            Ok(count as usize)
1066        })
1067    }
1068
1069    async fn get_top_users_activity_summary(
1070        &self,
1071        time_period: Range<OffsetDateTime>,
1072        max_user_count: usize,
1073    ) -> Result<Vec<UserActivitySummary>> {
1074        test_support!(self, {
1075            let query = "
1076                WITH
1077                    project_durations AS (
1078                        SELECT user_id, project_id, SUM(duration_millis) AS project_duration
1079                        FROM project_activity_periods
1080                        WHERE $1 < ended_at AND ended_at <= $2
1081                        GROUP BY user_id, project_id
1082                    ),
1083                    user_durations AS (
1084                        SELECT user_id, SUM(project_duration) as total_duration
1085                        FROM project_durations
1086                        GROUP BY user_id
1087                        ORDER BY total_duration DESC
1088                        LIMIT $3
1089                    ),
1090                    project_collaborators as (
1091                        SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
1092                        FROM project_durations
1093                        GROUP BY project_id
1094                    )
1095                SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
1096                FROM user_durations, project_durations, project_collaborators, users
1097                WHERE
1098                    user_durations.user_id = project_durations.user_id AND
1099                    user_durations.user_id = users.id AND
1100                    project_durations.project_id = project_collaborators.project_id
1101                ORDER BY total_duration DESC, user_id ASC, project_id ASC
1102            ";
1103
1104            let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
1105                .bind(time_period.start)
1106                .bind(time_period.end)
1107                .bind(max_user_count as i32)
1108                .fetch(&self.pool);
1109
1110            let mut result = Vec::<UserActivitySummary>::new();
1111            while let Some(row) = rows.next().await {
1112                let (user_id, github_login, project_id, duration_millis, project_collaborators) =
1113                    row?;
1114                let project_id = project_id;
1115                let duration = Duration::from_millis(duration_millis as u64);
1116                let project_activity = ProjectActivitySummary {
1117                    id: project_id,
1118                    duration,
1119                    max_collaborators: project_collaborators as usize,
1120                };
1121                if let Some(last_summary) = result.last_mut() {
1122                    if last_summary.id == user_id {
1123                        last_summary.project_activity.push(project_activity);
1124                        continue;
1125                    }
1126                }
1127                result.push(UserActivitySummary {
1128                    id: user_id,
1129                    project_activity: vec![project_activity],
1130                    github_login,
1131                });
1132            }
1133
1134            Ok(result)
1135        })
1136    }
1137
1138    async fn get_user_activity_timeline(
1139        &self,
1140        time_period: Range<OffsetDateTime>,
1141        user_id: UserId,
1142    ) -> Result<Vec<UserActivityPeriod>> {
1143        test_support!(self, {
1144            const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
1145
1146            let query = "
1147                SELECT
1148                    project_activity_periods.ended_at,
1149                    project_activity_periods.duration_millis,
1150                    project_activity_periods.project_id,
1151                    worktree_extensions.extension,
1152                    worktree_extensions.count
1153                FROM project_activity_periods
1154                LEFT OUTER JOIN
1155                    worktree_extensions
1156                ON
1157                    project_activity_periods.project_id = worktree_extensions.project_id
1158                WHERE
1159                    project_activity_periods.user_id = $1 AND
1160                    $2 < project_activity_periods.ended_at AND
1161                    project_activity_periods.ended_at <= $3
1162                ORDER BY project_activity_periods.id ASC
1163            ";
1164
1165            let mut rows = sqlx::query_as::<
1166                _,
1167                (
1168                    PrimitiveDateTime,
1169                    i32,
1170                    ProjectId,
1171                    Option<String>,
1172                    Option<i32>,
1173                ),
1174            >(query)
1175            .bind(user_id)
1176            .bind(time_period.start)
1177            .bind(time_period.end)
1178            .fetch(&self.pool);
1179
1180            let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
1181            while let Some(row) = rows.next().await {
1182                let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
1183                let ended_at = ended_at.assume_utc();
1184                let duration = Duration::from_millis(duration_millis as u64);
1185                let started_at = ended_at - duration;
1186                let project_time_periods = time_periods.entry(project_id).or_default();
1187
1188                if let Some(prev_duration) = project_time_periods.last_mut() {
1189                    if started_at <= prev_duration.end + COALESCE_THRESHOLD
1190                        && ended_at >= prev_duration.start
1191                    {
1192                        prev_duration.end = cmp::max(prev_duration.end, ended_at);
1193                    } else {
1194                        project_time_periods.push(UserActivityPeriod {
1195                            project_id,
1196                            start: started_at,
1197                            end: ended_at,
1198                            extensions: Default::default(),
1199                        });
1200                    }
1201                } else {
1202                    project_time_periods.push(UserActivityPeriod {
1203                        project_id,
1204                        start: started_at,
1205                        end: ended_at,
1206                        extensions: Default::default(),
1207                    });
1208                }
1209
1210                if let Some((extension, extension_count)) = extension.zip(extension_count) {
1211                    project_time_periods
1212                        .last_mut()
1213                        .unwrap()
1214                        .extensions
1215                        .insert(extension, extension_count as usize);
1216                }
1217            }
1218
1219            let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
1220            durations.sort_unstable_by_key(|duration| duration.start);
1221            Ok(durations)
1222        })
1223    }
1224
1225    // contacts
1226
1227    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
1228        test_support!(self, {
1229            let query = "
1230                SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
1231                FROM contacts
1232                WHERE user_id_a = $1 OR user_id_b = $1;
1233            ";
1234
1235            let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
1236                .bind(user_id)
1237                .fetch(&self.pool);
1238
1239            let mut contacts = Vec::new();
1240            while let Some(row) = rows.next().await {
1241                let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
1242
1243                if user_id_a == user_id {
1244                    if accepted {
1245                        contacts.push(Contact::Accepted {
1246                            user_id: user_id_b,
1247                            should_notify: should_notify && a_to_b,
1248                        });
1249                    } else if a_to_b {
1250                        contacts.push(Contact::Outgoing { user_id: user_id_b })
1251                    } else {
1252                        contacts.push(Contact::Incoming {
1253                            user_id: user_id_b,
1254                            should_notify,
1255                        });
1256                    }
1257                } else if accepted {
1258                    contacts.push(Contact::Accepted {
1259                        user_id: user_id_a,
1260                        should_notify: should_notify && !a_to_b,
1261                    });
1262                } else if a_to_b {
1263                    contacts.push(Contact::Incoming {
1264                        user_id: user_id_a,
1265                        should_notify,
1266                    });
1267                } else {
1268                    contacts.push(Contact::Outgoing { user_id: user_id_a });
1269                }
1270            }
1271
1272            contacts.sort_unstable_by_key(|contact| contact.user_id());
1273
1274            Ok(contacts)
1275        })
1276    }
1277
1278    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1279        test_support!(self, {
1280            let (id_a, id_b) = if user_id_1 < user_id_2 {
1281                (user_id_1, user_id_2)
1282            } else {
1283                (user_id_2, user_id_1)
1284            };
1285
1286            let query = "
1287                SELECT 1 FROM contacts
1288                WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
1289                LIMIT 1
1290            ";
1291            Ok(sqlx::query_scalar::<_, i32>(query)
1292                .bind(id_a.0)
1293                .bind(id_b.0)
1294                .fetch_optional(&self.pool)
1295                .await?
1296                .is_some())
1297        })
1298    }
1299
1300    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
1301        test_support!(self, {
1302            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
1303                (sender_id, receiver_id, true)
1304            } else {
1305                (receiver_id, sender_id, false)
1306            };
1307            let query = "
1308                INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1309                VALUES ($1, $2, $3, FALSE, TRUE)
1310                ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1311                SET
1312                    accepted = TRUE,
1313                    should_notify = FALSE
1314                WHERE
1315                    NOT contacts.accepted AND
1316                    ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1317                    (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1318            ";
1319            let result = sqlx::query(query)
1320                .bind(id_a.0)
1321                .bind(id_b.0)
1322                .bind(a_to_b)
1323                .execute(&self.pool)
1324                .await?;
1325
1326            if result.rows_affected() == 1 {
1327                Ok(())
1328            } else {
1329                Err(anyhow!("contact already requested"))?
1330            }
1331        })
1332    }
1333
1334    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1335        test_support!(self, {
1336            let (id_a, id_b) = if responder_id < requester_id {
1337                (responder_id, requester_id)
1338            } else {
1339                (requester_id, responder_id)
1340            };
1341            let query = "
1342                DELETE FROM contacts
1343                WHERE user_id_a = $1 AND user_id_b = $2;
1344            ";
1345            let result = sqlx::query(query)
1346                .bind(id_a.0)
1347                .bind(id_b.0)
1348                .execute(&self.pool)
1349                .await?;
1350
1351            if result.rows_affected() == 1 {
1352                Ok(())
1353            } else {
1354                Err(anyhow!("no such contact"))?
1355            }
1356        })
1357    }
1358
1359    async fn dismiss_contact_notification(
1360        &self,
1361        user_id: UserId,
1362        contact_user_id: UserId,
1363    ) -> Result<()> {
1364        test_support!(self, {
1365            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1366                (user_id, contact_user_id, true)
1367            } else {
1368                (contact_user_id, user_id, false)
1369            };
1370
1371            let query = "
1372                UPDATE contacts
1373                SET should_notify = FALSE
1374                WHERE
1375                    user_id_a = $1 AND user_id_b = $2 AND
1376                    (
1377                        (a_to_b = $3 AND accepted) OR
1378                        (a_to_b != $3 AND NOT accepted)
1379                    );
1380            ";
1381
1382            let result = sqlx::query(query)
1383                .bind(id_a.0)
1384                .bind(id_b.0)
1385                .bind(a_to_b)
1386                .execute(&self.pool)
1387                .await?;
1388
1389            if result.rows_affected() == 0 {
1390                Err(anyhow!("no such contact request"))?;
1391            }
1392
1393            Ok(())
1394        })
1395    }
1396
1397    async fn respond_to_contact_request(
1398        &self,
1399        responder_id: UserId,
1400        requester_id: UserId,
1401        accept: bool,
1402    ) -> Result<()> {
1403        test_support!(self, {
1404            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1405                (responder_id, requester_id, false)
1406            } else {
1407                (requester_id, responder_id, true)
1408            };
1409            let result = if accept {
1410                let query = "
1411                    UPDATE contacts
1412                    SET accepted = TRUE, should_notify = TRUE
1413                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1414                ";
1415                sqlx::query(query)
1416                    .bind(id_a.0)
1417                    .bind(id_b.0)
1418                    .bind(a_to_b)
1419                    .execute(&self.pool)
1420                    .await?
1421            } else {
1422                let query = "
1423                    DELETE FROM contacts
1424                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1425                ";
1426                sqlx::query(query)
1427                    .bind(id_a.0)
1428                    .bind(id_b.0)
1429                    .bind(a_to_b)
1430                    .execute(&self.pool)
1431                    .await?
1432            };
1433            if result.rows_affected() == 1 {
1434                Ok(())
1435            } else {
1436                Err(anyhow!("no such contact request"))?
1437            }
1438        })
1439    }
1440
1441    // access tokens
1442
1443    async fn create_access_token_hash(
1444        &self,
1445        user_id: UserId,
1446        access_token_hash: &str,
1447        max_access_token_count: usize,
1448    ) -> Result<()> {
1449        test_support!(self, {
1450            let insert_query = "
1451                INSERT INTO access_tokens (user_id, hash)
1452                VALUES ($1, $2);
1453            ";
1454            let cleanup_query = "
1455                DELETE FROM access_tokens
1456                WHERE id IN (
1457                    SELECT id from access_tokens
1458                    WHERE user_id = $1
1459                    ORDER BY id DESC
1460                    OFFSET $3
1461                )
1462            ";
1463
1464            let mut tx = self.pool.begin().await?;
1465            sqlx::query(insert_query)
1466                .bind(user_id.0)
1467                .bind(access_token_hash)
1468                .execute(&mut tx)
1469                .await?;
1470            sqlx::query(cleanup_query)
1471                .bind(user_id.0)
1472                .bind(access_token_hash)
1473                .bind(max_access_token_count as i32)
1474                .execute(&mut tx)
1475                .await?;
1476            Ok(tx.commit().await?)
1477        })
1478    }
1479
1480    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1481        test_support!(self, {
1482            let query = "
1483                SELECT hash
1484                FROM access_tokens
1485                WHERE user_id = $1
1486                ORDER BY id DESC
1487            ";
1488            Ok(sqlx::query_scalar(query)
1489                .bind(user_id.0)
1490                .fetch_all(&self.pool)
1491                .await?)
1492        })
1493    }
1494
1495    // orgs
1496
1497    #[allow(unused)] // Help rust-analyzer
1498    #[cfg(any(test, feature = "seed-support"))]
1499    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1500        test_support!(self, {
1501            let query = "
1502                SELECT *
1503                FROM orgs
1504                WHERE slug = $1
1505            ";
1506            Ok(sqlx::query_as(query)
1507                .bind(slug)
1508                .fetch_optional(&self.pool)
1509                .await?)
1510        })
1511    }
1512
1513    #[cfg(any(test, feature = "seed-support"))]
1514    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1515        test_support!(self, {
1516            let query = "
1517                INSERT INTO orgs (name, slug)
1518                VALUES ($1, $2)
1519                RETURNING id
1520            ";
1521            Ok(sqlx::query_scalar(query)
1522                .bind(name)
1523                .bind(slug)
1524                .fetch_one(&self.pool)
1525                .await
1526                .map(OrgId)?)
1527        })
1528    }
1529
1530    #[cfg(any(test, feature = "seed-support"))]
1531    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1532        test_support!(self, {
1533            let query = "
1534                INSERT INTO org_memberships (org_id, user_id, admin)
1535                VALUES ($1, $2, $3)
1536                ON CONFLICT DO NOTHING
1537            ";
1538            Ok(sqlx::query(query)
1539                .bind(org_id.0)
1540                .bind(user_id.0)
1541                .bind(is_admin)
1542                .execute(&self.pool)
1543                .await
1544                .map(drop)?)
1545        })
1546    }
1547
1548    // channels
1549
1550    #[cfg(any(test, feature = "seed-support"))]
1551    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1552        test_support!(self, {
1553            let query = "
1554                INSERT INTO channels (owner_id, owner_is_user, name)
1555                VALUES ($1, false, $2)
1556                RETURNING id
1557            ";
1558            Ok(sqlx::query_scalar(query)
1559                .bind(org_id.0)
1560                .bind(name)
1561                .fetch_one(&self.pool)
1562                .await
1563                .map(ChannelId)?)
1564        })
1565    }
1566
1567    #[allow(unused)] // Help rust-analyzer
1568    #[cfg(any(test, feature = "seed-support"))]
1569    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1570        test_support!(self, {
1571            let query = "
1572                SELECT *
1573                FROM channels
1574                WHERE
1575                    channels.owner_is_user = false AND
1576                    channels.owner_id = $1
1577            ";
1578            Ok(sqlx::query_as(query)
1579                .bind(org_id.0)
1580                .fetch_all(&self.pool)
1581                .await?)
1582        })
1583    }
1584
1585    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1586        test_support!(self, {
1587            let query = "
1588                SELECT
1589                    channels.*
1590                FROM
1591                    channel_memberships, channels
1592                WHERE
1593                    channel_memberships.user_id = $1 AND
1594                    channel_memberships.channel_id = channels.id
1595            ";
1596            Ok(sqlx::query_as(query)
1597                .bind(user_id.0)
1598                .fetch_all(&self.pool)
1599                .await?)
1600        })
1601    }
1602
1603    async fn can_user_access_channel(
1604        &self,
1605        user_id: UserId,
1606        channel_id: ChannelId,
1607    ) -> Result<bool> {
1608        test_support!(self, {
1609            let query = "
1610                SELECT id
1611                FROM channel_memberships
1612                WHERE user_id = $1 AND channel_id = $2
1613                LIMIT 1
1614            ";
1615            Ok(sqlx::query_scalar::<_, i32>(query)
1616                .bind(user_id.0)
1617                .bind(channel_id.0)
1618                .fetch_optional(&self.pool)
1619                .await
1620                .map(|e| e.is_some())?)
1621        })
1622    }
1623
1624    #[cfg(any(test, feature = "seed-support"))]
1625    async fn add_channel_member(
1626        &self,
1627        channel_id: ChannelId,
1628        user_id: UserId,
1629        is_admin: bool,
1630    ) -> Result<()> {
1631        test_support!(self, {
1632            let query = "
1633                INSERT INTO channel_memberships (channel_id, user_id, admin)
1634                VALUES ($1, $2, $3)
1635                ON CONFLICT DO NOTHING
1636            ";
1637            Ok(sqlx::query(query)
1638                .bind(channel_id.0)
1639                .bind(user_id.0)
1640                .bind(is_admin)
1641                .execute(&self.pool)
1642                .await
1643                .map(drop)?)
1644        })
1645    }
1646
1647    // messages
1648
1649    async fn create_channel_message(
1650        &self,
1651        channel_id: ChannelId,
1652        sender_id: UserId,
1653        body: &str,
1654        timestamp: OffsetDateTime,
1655        nonce: u128,
1656    ) -> Result<MessageId> {
1657        test_support!(self, {
1658            let query = "
1659                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1660                VALUES ($1, $2, $3, $4, $5)
1661                ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1662                RETURNING id
1663            ";
1664            Ok(sqlx::query_scalar(query)
1665                .bind(channel_id.0)
1666                .bind(sender_id.0)
1667                .bind(body)
1668                .bind(timestamp)
1669                .bind(Uuid::from_u128(nonce))
1670                .fetch_one(&self.pool)
1671                .await
1672                .map(MessageId)?)
1673        })
1674    }
1675
1676    async fn get_channel_messages(
1677        &self,
1678        channel_id: ChannelId,
1679        count: usize,
1680        before_id: Option<MessageId>,
1681    ) -> Result<Vec<ChannelMessage>> {
1682        test_support!(self, {
1683            let query = r#"
1684                SELECT * FROM (
1685                    SELECT
1686                        id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1687                    FROM
1688                        channel_messages
1689                    WHERE
1690                        channel_id = $1 AND
1691                        id < $2
1692                    ORDER BY id DESC
1693                    LIMIT $3
1694                ) as recent_messages
1695                ORDER BY id ASC
1696            "#;
1697            Ok(sqlx::query_as(query)
1698                .bind(channel_id.0)
1699                .bind(before_id.unwrap_or(MessageId::MAX))
1700                .bind(count as i64)
1701                .fetch_all(&self.pool)
1702                .await?)
1703        })
1704    }
1705
    /// Test-only teardown of the database at `url`.
    ///
    /// Runs a connection-termination query, closes the pool, then drops the database through
    /// `sqlx::Sqlite`'s `MigrateDatabase::drop_database`. Errors from the query and from the drop
    /// are handed to `log_err` and never propagated; timing is printed to stderr.
    #[cfg(test)]
    async fn teardown(&self, url: &str) {
        let start = std::time::Instant::now();
        eprintln!("tearing down database...");
        test_support!(self, {
            use util::ResultExt;

            // NOTE(review): `pg_terminate_backend` / `pg_stat_activity` are PostgreSQL-specific,
            // yet the database is dropped below via `sqlx::Sqlite`. On a SQLite pool this query
            // presumably fails and is only logged — confirm which backend `self.pool` targets.
            let query = "
                SELECT pg_terminate_backend(pg_stat_activity.pid)
                FROM pg_stat_activity
                WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
            ";
            sqlx::query(query).execute(&self.pool).await.log_err();
            // Close our own connections before deleting the database.
            self.pool.close().await;
            <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
                .await
                .log_err();
            eprintln!("tore down database: {:?}", start.elapsed());
        })
    }
1726
    /// Always returns `None`; only [`FakeDb`] hands out a reference to itself here.
    #[cfg(test)]
    fn as_fake(&self) -> Option<&FakeDb> {
        None
    }
1731}
1732
/// Declares `$type_name` as a `Copy` newtype around an `i32` database id.
///
/// The generated type is `#[sqlx(transparent)]` and `#[serde(transparent)]`, so it is stored and
/// serialized exactly like the wrapped `i32`. It also gets a `MAX` constant, lossy `u64`
/// conversions (`from_proto` / `to_proto`), and a `Display` impl that forwards to the `i32`.
macro_rules! id_type {
    ($type_name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default,
            PartialEq, Eq, PartialOrd, Ord, Hash,
            sqlx::Type, Serialize, Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $type_name(pub i32);

        // Not every id type uses every helper, so silence dead-code lints for the whole impl.
        #[allow(unused)]
        impl $type_name {
            /// The largest representable id (`i32::MAX`).
            pub const MAX: Self = Self(i32::MAX);

            /// Builds an id from a `u64` with an `as` cast; values above `i32::MAX` do not
            /// round-trip (they are truncated to the low 32 bits).
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Widens the id to `u64` with an `as` cast (a negative id would sign-extend).
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $type_name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                // Forward to `i32`'s `Display` so width/alignment flags are honored.
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1775
// Typed id for `User` rows.
id_type!(UserId);
/// A user row, decodable from the database via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    /// GitHub username; the fake database also uses it to detect existing users.
    pub github_login: String,
    /// Numeric GitHub account id, or `None` if none has been recorded yet.
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    /// The `admin` flag passed to `Db::create_user`.
    pub admin: bool,
    pub invite_code: Option<String>,
    pub invite_count: i32,
    /// Updated via `Db::set_user_connected_once`.
    pub connected_once: bool,
}
1788
// Typed id for `Project` rows.
id_type!(ProjectId);
/// A project row, created by `Db::register_project` for its host user.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// The user that registered the project.
    pub host_user_id: UserId,
    /// Set by `Db::unregister_project` (the fake flips this flag instead of deleting the row).
    pub unregistered: bool,
}
1796
/// Per-user activity summary, as returned by `Db::get_top_users_activity_summary`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// One entry per project the user was active in during the queried period.
    pub project_activity: Vec<ProjectActivitySummary>,
}
1803
/// Activity in a single project, nested inside a [`UserActivitySummary`].
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ProjectActivitySummary {
    pub id: ProjectId,
    /// Presumably the total active time in this project — confirm against the real query.
    pub duration: Duration,
    /// Presumably the peak number of concurrent collaborators — confirm against the real query.
    pub max_collaborators: usize,
}
1810
/// One contiguous period of activity in a project, as returned by
/// `Db::get_user_activity_timeline`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    pub project_id: ProjectId,
    /// Serialized as an ISO 8601 timestamp.
    #[serde(with = "time::serde::iso8601")]
    pub start: OffsetDateTime,
    /// Serialized as an ISO 8601 timestamp.
    #[serde(with = "time::serde::iso8601")]
    pub end: OffsetDateTime,
    /// NOTE(review): presumably file-extension -> file count for the project's worktrees
    /// (cf. `Db::update_worktree_extensions`); confirm against the real query.
    pub extensions: HashMap<String, usize>,
}
1820
// Typed id for `Org` rows.
id_type!(OrgId);
/// An organization row.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Lookup key for `Db::find_org_by_slug`; the fake `create_org` rejects duplicates.
    pub slug: String,
}
1828
// Typed id for `Channel` rows.
id_type!(ChannelId);
/// A chat channel row.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Raw id of the owner: a user id when `owner_is_user`, otherwise an org id
    /// (the fake `create_org_channel` stores `org_id.0` with `owner_is_user: false`).
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1837
// Typed id for `ChannelMessage` rows.
id_type!(MessageId);
/// A message posted to a channel.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// The real query selects this column `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied nonce; the fake `create_channel_message` returns the existing message's
    /// id instead of inserting when a message with the same nonce already exists.
    pub nonce: Uuid,
}
1848
/// One entry in a user's contact list, seen from that user's point of view.
/// In every variant `user_id` is the *other* user.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The request has been accepted (regardless of which side sent it).
    Accepted {
        user_id: UserId,
        // NOTE(review): presumably whether this user should be told the request was accepted;
        // the fake only ever sets it for the original requester.
        should_notify: bool,
    },
    /// This user sent a request that is still pending.
    Outgoing {
        user_id: UserId,
    },
    /// The other user sent a request that is still pending.
    Incoming {
        user_id: UserId,
        /// Cleared by `Db::dismiss_contact_notification`.
        should_notify: bool,
    },
}
1863
1864impl Contact {
1865    pub fn user_id(&self) -> UserId {
1866        match self {
1867            Contact::Accepted { user_id, .. } => *user_id,
1868            Contact::Outgoing { user_id } => *user_id,
1869            Contact::Incoming { user_id, .. } => *user_id,
1870        }
1871    }
1872}
1873
/// A pending contact request addressed to some user.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    /// The user who sent the request.
    pub requester_id: UserId,
    pub should_notify: bool,
}
1879
/// Waitlist signup data, as passed to `Db::create_signup`.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    /// Platforms the person selected.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
}
1890
/// Aggregate waitlist counts, as returned by `Db::get_waitlist_summary`.
///
/// Each field is `#[sqlx(default)]`, so it decodes as `0` when its column is absent from the row.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
1904
/// An invitation addressed to an email address.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    /// NOTE(review): presumably produced by `random_email_confirmation_code`; confirm in the
    /// real `create_invite_from_code` / `get_unsent_invites`.
    pub email_confirmation_code: String,
}
1910
/// Parameters for `Db::create_user` and `Db::create_user_from_invite`.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}
1917
/// Outcome of creating a user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    /// The user whose invite led to this signup, if any (the fake always reports `None`).
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
1925
/// Generates a random 16-character invite code with `nanoid::nanoid!`.
fn random_invite_code() -> String {
    nanoid::nanoid!(16)
}
1929
/// Generates a random 64-character email confirmation code with `nanoid::nanoid!`.
fn random_email_confirmation_code() -> String {
    nanoid::nanoid!(64)
}
1933
1934#[cfg(test)]
1935pub use test::*;
1936
1937#[cfg(test)]
1938mod test {
1939    use super::*;
1940    use anyhow::anyhow;
1941    use collections::BTreeMap;
1942    use gpui::executor::Background;
1943    use parking_lot::Mutex;
1944    use rand::prelude::*;
1945    use sqlx::{migrate::MigrateDatabase, Sqlite};
1946    use std::sync::Arc;
1947    use util::post_inc;
1948
    /// In-memory stand-in for the real database, used by tests.
    ///
    /// Each table is a `Mutex`-guarded `BTreeMap` (contacts use a `Vec`). The `next_*` counters
    /// are advanced with `post_inc` whenever a new row of that kind needs an id.
    pub struct FakeDb {
        /// Executor used to inject randomized delays into every async call.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        /// Extension counts keyed by `(project, worktree id, extension)`.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        /// Value is the `is_admin` flag passed to `add_org_member`.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        /// Value is the `is_admin` flag passed to `add_channel_member`.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        pub contacts: Mutex<Vec<FakeContact>>,
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
1966
    /// A contact request between two users as stored by [`FakeDb`].
    #[derive(Debug)]
    pub struct FakeContact {
        /// The user who sent the request.
        pub requester_id: UserId,
        /// The user the request was sent to.
        pub responder_id: UserId,
        /// `false` while the request is pending.
        pub accepted: bool,
        pub should_notify: bool,
    }
1974
1975    impl FakeDb {
1976        pub fn new(background: Arc<Background>) -> Self {
1977            Self {
1978                background,
1979                users: Default::default(),
1980                next_user_id: Mutex::new(0),
1981                projects: Default::default(),
1982                worktree_extensions: Default::default(),
1983                next_project_id: Mutex::new(1),
1984                orgs: Default::default(),
1985                next_org_id: Mutex::new(1),
1986                org_memberships: Default::default(),
1987                channels: Default::default(),
1988                next_channel_id: Mutex::new(1),
1989                channel_memberships: Default::default(),
1990                channel_messages: Default::default(),
1991                next_channel_message_id: Mutex::new(1),
1992                contacts: Default::default(),
1993            }
1994        }
1995    }
1996
1997    #[async_trait]
1998    impl Db for FakeDb {
1999        async fn create_user(
2000            &self,
2001            email_address: &str,
2002            admin: bool,
2003            params: NewUserParams,
2004        ) -> Result<NewUserResult> {
2005            self.background.simulate_random_delay().await;
2006
2007            let mut users = self.users.lock();
2008            let user_id = if let Some(user) = users
2009                .values()
2010                .find(|user| user.github_login == params.github_login)
2011            {
2012                user.id
2013            } else {
2014                let id = post_inc(&mut *self.next_user_id.lock());
2015                let user_id = UserId(id);
2016                users.insert(
2017                    user_id,
2018                    User {
2019                        id: user_id,
2020                        github_login: params.github_login,
2021                        github_user_id: Some(params.github_user_id),
2022                        email_address: Some(email_address.to_string()),
2023                        admin,
2024                        invite_code: None,
2025                        invite_count: 0,
2026                        connected_once: false,
2027                    },
2028                );
2029                user_id
2030            };
2031            Ok(NewUserResult {
2032                user_id,
2033                metrics_id: "the-metrics-id".to_string(),
2034                inviting_user_id: None,
2035                signup_device_id: None,
2036            })
2037        }
2038
2039        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
2040            unimplemented!()
2041        }
2042
2043        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
2044            unimplemented!()
2045        }
2046
2047        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2048            self.background.simulate_random_delay().await;
2049            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2050        }
2051
2052        async fn get_user_metrics_id(&self, _id: UserId) -> Result<String> {
2053            Ok("the-metrics-id".to_string())
2054        }
2055
2056        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2057            self.background.simulate_random_delay().await;
2058            let users = self.users.lock();
2059            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2060        }
2061
2062        async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
2063            unimplemented!()
2064        }
2065
2066        async fn get_user_by_github_account(
2067            &self,
2068            github_login: &str,
2069            github_user_id: Option<i32>,
2070        ) -> Result<Option<User>> {
2071            self.background.simulate_random_delay().await;
2072            if let Some(github_user_id) = github_user_id {
2073                for user in self.users.lock().values_mut() {
2074                    if user.github_user_id == Some(github_user_id) {
2075                        user.github_login = github_login.into();
2076                        return Ok(Some(user.clone()));
2077                    }
2078                    if user.github_login == github_login {
2079                        user.github_user_id = Some(github_user_id);
2080                        return Ok(Some(user.clone()));
2081                    }
2082                }
2083                Ok(None)
2084            } else {
2085                Ok(self
2086                    .users
2087                    .lock()
2088                    .values()
2089                    .find(|user| user.github_login == github_login)
2090                    .cloned())
2091            }
2092        }
2093
2094        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
2095            unimplemented!()
2096        }
2097
2098        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2099            self.background.simulate_random_delay().await;
2100            let mut users = self.users.lock();
2101            let mut user = users
2102                .get_mut(&id)
2103                .ok_or_else(|| anyhow!("user not found"))?;
2104            user.connected_once = connected_once;
2105            Ok(())
2106        }
2107
2108        async fn destroy_user(&self, _id: UserId) -> Result<()> {
2109            unimplemented!()
2110        }
2111
2112        // signups
2113
2114        async fn create_signup(&self, _signup: Signup) -> Result<()> {
2115            unimplemented!()
2116        }
2117
2118        async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
2119            unimplemented!()
2120        }
2121
2122        async fn get_unsent_invites(&self, _count: usize) -> Result<Vec<Invite>> {
2123            unimplemented!()
2124        }
2125
2126        async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
2127            unimplemented!()
2128        }
2129
2130        async fn create_user_from_invite(
2131            &self,
2132            _invite: &Invite,
2133            _user: NewUserParams,
2134        ) -> Result<Option<NewUserResult>> {
2135            unimplemented!()
2136        }
2137
2138        // invite codes
2139
2140        async fn set_invite_count_for_user(&self, _id: UserId, _count: u32) -> Result<()> {
2141            unimplemented!()
2142        }
2143
2144        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
2145            self.background.simulate_random_delay().await;
2146            Ok(None)
2147        }
2148
2149        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
2150            unimplemented!()
2151        }
2152
2153        async fn create_invite_from_code(
2154            &self,
2155            _code: &str,
2156            _email_address: &str,
2157            _device_id: Option<&str>,
2158        ) -> Result<Invite> {
2159            unimplemented!()
2160        }
2161
2162        // projects
2163
2164        async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2165            self.background.simulate_random_delay().await;
2166            if !self.users.lock().contains_key(&host_user_id) {
2167                Err(anyhow!("no such user"))?;
2168            }
2169
2170            let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2171            self.projects.lock().insert(
2172                project_id,
2173                Project {
2174                    id: project_id,
2175                    host_user_id,
2176                    unregistered: false,
2177                },
2178            );
2179            Ok(project_id)
2180        }
2181
2182        async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2183            self.background.simulate_random_delay().await;
2184            self.projects
2185                .lock()
2186                .get_mut(&project_id)
2187                .ok_or_else(|| anyhow!("no such project"))?
2188                .unregistered = true;
2189            Ok(())
2190        }
2191
2192        async fn update_worktree_extensions(
2193            &self,
2194            project_id: ProjectId,
2195            worktree_id: u64,
2196            extensions: HashMap<String, u32>,
2197        ) -> Result<()> {
2198            self.background.simulate_random_delay().await;
2199            if !self.projects.lock().contains_key(&project_id) {
2200                Err(anyhow!("no such project"))?;
2201            }
2202
2203            for (extension, count) in extensions {
2204                self.worktree_extensions
2205                    .lock()
2206                    .insert((project_id, worktree_id, extension), count);
2207            }
2208
2209            Ok(())
2210        }
2211
2212        async fn get_project_extensions(
2213            &self,
2214            _project_id: ProjectId,
2215        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
2216            unimplemented!()
2217        }
2218
2219        async fn record_user_activity(
2220            &self,
2221            _time_period: Range<OffsetDateTime>,
2222            _active_projects: &[(UserId, ProjectId)],
2223        ) -> Result<()> {
2224            unimplemented!()
2225        }
2226
2227        async fn get_active_user_count(
2228            &self,
2229            _time_period: Range<OffsetDateTime>,
2230            _min_duration: Duration,
2231            _only_collaborative: bool,
2232        ) -> Result<usize> {
2233            unimplemented!()
2234        }
2235
2236        async fn get_top_users_activity_summary(
2237            &self,
2238            _time_period: Range<OffsetDateTime>,
2239            _limit: usize,
2240        ) -> Result<Vec<UserActivitySummary>> {
2241            unimplemented!()
2242        }
2243
2244        async fn get_user_activity_timeline(
2245            &self,
2246            _time_period: Range<OffsetDateTime>,
2247            _user_id: UserId,
2248        ) -> Result<Vec<UserActivityPeriod>> {
2249            unimplemented!()
2250        }
2251
2252        // contacts
2253
2254        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2255            self.background.simulate_random_delay().await;
2256            let mut contacts = Vec::new();
2257
2258            for contact in self.contacts.lock().iter() {
2259                if contact.requester_id == id {
2260                    if contact.accepted {
2261                        contacts.push(Contact::Accepted {
2262                            user_id: contact.responder_id,
2263                            should_notify: contact.should_notify,
2264                        });
2265                    } else {
2266                        contacts.push(Contact::Outgoing {
2267                            user_id: contact.responder_id,
2268                        });
2269                    }
2270                } else if contact.responder_id == id {
2271                    if contact.accepted {
2272                        contacts.push(Contact::Accepted {
2273                            user_id: contact.requester_id,
2274                            should_notify: false,
2275                        });
2276                    } else {
2277                        contacts.push(Contact::Incoming {
2278                            user_id: contact.requester_id,
2279                            should_notify: contact.should_notify,
2280                        });
2281                    }
2282                }
2283            }
2284
2285            contacts.sort_unstable_by_key(|contact| contact.user_id());
2286            Ok(contacts)
2287        }
2288
2289        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2290            self.background.simulate_random_delay().await;
2291            Ok(self.contacts.lock().iter().any(|contact| {
2292                contact.accepted
2293                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2294                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2295            }))
2296        }
2297
2298        async fn send_contact_request(
2299            &self,
2300            requester_id: UserId,
2301            responder_id: UserId,
2302        ) -> Result<()> {
2303            self.background.simulate_random_delay().await;
2304            let mut contacts = self.contacts.lock();
2305            for contact in contacts.iter_mut() {
2306                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2307                    if contact.accepted {
2308                        Err(anyhow!("contact already exists"))?;
2309                    } else {
2310                        Err(anyhow!("contact already requested"))?;
2311                    }
2312                }
2313                if contact.responder_id == requester_id && contact.requester_id == responder_id {
2314                    if contact.accepted {
2315                        Err(anyhow!("contact already exists"))?;
2316                    } else {
2317                        contact.accepted = true;
2318                        contact.should_notify = false;
2319                        return Ok(());
2320                    }
2321                }
2322            }
2323            contacts.push(FakeContact {
2324                requester_id,
2325                responder_id,
2326                accepted: false,
2327                should_notify: true,
2328            });
2329            Ok(())
2330        }
2331
2332        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2333            self.background.simulate_random_delay().await;
2334            self.contacts.lock().retain(|contact| {
2335                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2336            });
2337            Ok(())
2338        }
2339
2340        async fn dismiss_contact_notification(
2341            &self,
2342            user_id: UserId,
2343            contact_user_id: UserId,
2344        ) -> Result<()> {
2345            self.background.simulate_random_delay().await;
2346            let mut contacts = self.contacts.lock();
2347            for contact in contacts.iter_mut() {
2348                if contact.requester_id == contact_user_id
2349                    && contact.responder_id == user_id
2350                    && !contact.accepted
2351                {
2352                    contact.should_notify = false;
2353                    return Ok(());
2354                }
2355                if contact.requester_id == user_id
2356                    && contact.responder_id == contact_user_id
2357                    && contact.accepted
2358                {
2359                    contact.should_notify = false;
2360                    return Ok(());
2361                }
2362            }
2363            Err(anyhow!("no such notification"))?
2364        }
2365
2366        async fn respond_to_contact_request(
2367            &self,
2368            responder_id: UserId,
2369            requester_id: UserId,
2370            accept: bool,
2371        ) -> Result<()> {
2372            self.background.simulate_random_delay().await;
2373            let mut contacts = self.contacts.lock();
2374            for (ix, contact) in contacts.iter_mut().enumerate() {
2375                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2376                    if contact.accepted {
2377                        Err(anyhow!("contact already confirmed"))?;
2378                    }
2379                    if accept {
2380                        contact.accepted = true;
2381                        contact.should_notify = true;
2382                    } else {
2383                        contacts.remove(ix);
2384                    }
2385                    return Ok(());
2386                }
2387            }
2388            Err(anyhow!("no such contact request"))?
2389        }
2390
2391        async fn create_access_token_hash(
2392            &self,
2393            _user_id: UserId,
2394            _access_token_hash: &str,
2395            _max_access_token_count: usize,
2396        ) -> Result<()> {
2397            unimplemented!()
2398        }
2399
2400        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
2401            unimplemented!()
2402        }
2403
2404        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
2405            unimplemented!()
2406        }
2407
2408        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2409            self.background.simulate_random_delay().await;
2410            let mut orgs = self.orgs.lock();
2411            if orgs.values().any(|org| org.slug == slug) {
2412                Err(anyhow!("org already exists"))?
2413            } else {
2414                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2415                orgs.insert(
2416                    org_id,
2417                    Org {
2418                        id: org_id,
2419                        name: name.to_string(),
2420                        slug: slug.to_string(),
2421                    },
2422                );
2423                Ok(org_id)
2424            }
2425        }
2426
2427        async fn add_org_member(
2428            &self,
2429            org_id: OrgId,
2430            user_id: UserId,
2431            is_admin: bool,
2432        ) -> Result<()> {
2433            self.background.simulate_random_delay().await;
2434            if !self.orgs.lock().contains_key(&org_id) {
2435                Err(anyhow!("org does not exist"))?;
2436            }
2437            if !self.users.lock().contains_key(&user_id) {
2438                Err(anyhow!("user does not exist"))?;
2439            }
2440
2441            self.org_memberships
2442                .lock()
2443                .entry((org_id, user_id))
2444                .or_insert(is_admin);
2445            Ok(())
2446        }
2447
2448        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2449            self.background.simulate_random_delay().await;
2450            if !self.orgs.lock().contains_key(&org_id) {
2451                Err(anyhow!("org does not exist"))?;
2452            }
2453
2454            let mut channels = self.channels.lock();
2455            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2456            channels.insert(
2457                channel_id,
2458                Channel {
2459                    id: channel_id,
2460                    name: name.to_string(),
2461                    owner_id: org_id.0,
2462                    owner_is_user: false,
2463                },
2464            );
2465            Ok(channel_id)
2466        }
2467
2468        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2469            self.background.simulate_random_delay().await;
2470            Ok(self
2471                .channels
2472                .lock()
2473                .values()
2474                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2475                .cloned()
2476                .collect())
2477        }
2478
2479        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2480            self.background.simulate_random_delay().await;
2481            let channels = self.channels.lock();
2482            let memberships = self.channel_memberships.lock();
2483            Ok(channels
2484                .values()
2485                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2486                .cloned()
2487                .collect())
2488        }
2489
2490        async fn can_user_access_channel(
2491            &self,
2492            user_id: UserId,
2493            channel_id: ChannelId,
2494        ) -> Result<bool> {
2495            self.background.simulate_random_delay().await;
2496            Ok(self
2497                .channel_memberships
2498                .lock()
2499                .contains_key(&(channel_id, user_id)))
2500        }
2501
2502        async fn add_channel_member(
2503            &self,
2504            channel_id: ChannelId,
2505            user_id: UserId,
2506            is_admin: bool,
2507        ) -> Result<()> {
2508            self.background.simulate_random_delay().await;
2509            if !self.channels.lock().contains_key(&channel_id) {
2510                Err(anyhow!("channel does not exist"))?;
2511            }
2512            if !self.users.lock().contains_key(&user_id) {
2513                Err(anyhow!("user does not exist"))?;
2514            }
2515
2516            self.channel_memberships
2517                .lock()
2518                .entry((channel_id, user_id))
2519                .or_insert(is_admin);
2520            Ok(())
2521        }
2522
2523        async fn create_channel_message(
2524            &self,
2525            channel_id: ChannelId,
2526            sender_id: UserId,
2527            body: &str,
2528            timestamp: OffsetDateTime,
2529            nonce: u128,
2530        ) -> Result<MessageId> {
2531            self.background.simulate_random_delay().await;
2532            if !self.channels.lock().contains_key(&channel_id) {
2533                Err(anyhow!("channel does not exist"))?;
2534            }
2535            if !self.users.lock().contains_key(&sender_id) {
2536                Err(anyhow!("user does not exist"))?;
2537            }
2538
2539            let mut messages = self.channel_messages.lock();
2540            if let Some(message) = messages
2541                .values()
2542                .find(|message| message.nonce.as_u128() == nonce)
2543            {
2544                Ok(message.id)
2545            } else {
2546                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2547                messages.insert(
2548                    message_id,
2549                    ChannelMessage {
2550                        id: message_id,
2551                        channel_id,
2552                        sender_id,
2553                        body: body.to_string(),
2554                        sent_at: timestamp,
2555                        nonce: Uuid::from_u128(nonce),
2556                    },
2557                );
2558                Ok(message_id)
2559            }
2560        }
2561
2562        async fn get_channel_messages(
2563            &self,
2564            channel_id: ChannelId,
2565            count: usize,
2566            before_id: Option<MessageId>,
2567        ) -> Result<Vec<ChannelMessage>> {
2568            self.background.simulate_random_delay().await;
2569            let mut messages = self
2570                .channel_messages
2571                .lock()
2572                .values()
2573                .rev()
2574                .filter(|message| {
2575                    message.channel_id == channel_id
2576                        && message.id < before_id.unwrap_or(MessageId::MAX)
2577                })
2578                .take(count)
2579                .cloned()
2580                .collect::<Vec<_>>();
2581            messages.sort_unstable_by_key(|message| message.id);
2582            Ok(messages)
2583        }
2584
2585        async fn teardown(&self, _: &str) {}
2586
2587        #[cfg(test)]
2588        fn as_fake(&self) -> Option<&FakeDb> {
2589            Some(self)
2590        }
2591    }
2592
    /// A database handle for one test: either a real SQLite-backed `RealDb` whose file lives at
    /// `url`, or a [`FakeDb`] (in which case `url` is empty).
    pub struct TestDb {
        /// `Some` for the whole life of the value; taken out in `Drop`.
        pub db: Option<Arc<dyn Db>>,
        pub url: String,
    }
2597
2598    impl TestDb {
2599        #[allow(clippy::await_holding_lock)]
2600        pub async fn real() -> Self {
2601            eprintln!("creating database...");
2602            let start = std::time::Instant::now();
2603            let mut rng = StdRng::from_entropy();
2604            let url = format!("/tmp/zed-test-{}", rng.gen::<u128>());
2605            Sqlite::create_database(&url).await.unwrap();
2606            let db = RealDb::new(&url, 5).await.unwrap();
2607            db.migrate(Path::new(TEST_MIGRATIONS_PATH.unwrap()), false)
2608                .await
2609                .unwrap();
2610
2611            eprintln!("created database: {:?}", start.elapsed());
2612            Self {
2613                db: Some(Arc::new(db)),
2614                url,
2615            }
2616        }
2617
2618        pub fn fake(background: Arc<Background>) -> Self {
2619            Self {
2620                db: Some(Arc::new(FakeDb::new(background))),
2621                url: Default::default(),
2622            }
2623        }
2624
2625        pub fn db(&self) -> &Arc<dyn Db> {
2626            self.db.as_ref().unwrap()
2627        }
2628    }
2629
2630    impl Drop for TestDb {
2631        fn drop(&mut self) {
2632            if let Some(db) = self.db.take() {
2633                std::fs::remove_file(&self.url).ok();
2634            }
2635        }
2636    }
2637}