db.rs

   1use crate::{Error, Result};
   2use anyhow::{anyhow, Context};
   3use async_trait::async_trait;
   4use axum::http::StatusCode;
   5use collections::HashMap;
   6use futures::StreamExt;
   7use serde::{Deserialize, Serialize};
   8pub use sqlx::postgres::PgPoolOptions as DbOptions;
   9use sqlx::{types::Uuid, FromRow, QueryBuilder};
  10use std::{cmp, ops::Range, time::Duration};
  11use time::{OffsetDateTime, PrimitiveDateTime};
  12
/// Persistence interface used by the collaboration server.
///
/// [`PostgresDb`] in this file is a Postgres-backed implementation. Several methods are only
/// declared in test builds and/or with the `seed-support` feature; every implementor has to apply
/// the same `#[cfg(...)]` gating, otherwise the impl will not match the trait in some build
/// configuration.
#[async_trait]
pub trait Db: Send + Sync {
    // --- users ---

    /// Creates a user with the given email address, admin flag and GitHub identity and returns
    /// its id. (The Postgres implementation upserts on `github_login` and returns the existing
    /// row's id when that login is already taken.)
    async fn create_user(
        &self,
        email_address: &str,
        admin: bool,
        params: NewUserParams,
    ) -> Result<UserId>;
    /// Returns page `page` (zero-based) of users, with `limit` users per page.
    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
    /// Returns at most `limit` users whose GitHub login loosely matches `query`.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Returns the user with the given id, or `None` if there is none.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Returns the users whose ids appear in `ids`; ids without a matching user are skipped.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns users whose `invite_count` is zero, restricted to users who were (`true`) or were
    /// not (`false`) invited by another user.
    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
    /// Finds a user by GitHub login and, when known, GitHub user id. The Postgres implementation
    /// also rewrites whichever of the two identifiers is out of date on the stored row.
    async fn get_user_by_github_account(
        &self,
        github_login: &str,
        github_user_id: Option<i32>,
    ) -> Result<Option<User>>;
    /// Sets or clears the admin flag of user `id`.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Sets the `connected_once` flag of user `id`.
    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
    /// Deletes user `id` together with their access tokens.
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    // --- invite codes ---

    /// Sets how many invites user `id` may still send. When `count > 0` and the user has no
    /// invite code yet, one is generated.
    async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()>;
    /// Returns the user's invite code and remaining invite count, or `None` if they have no code.
    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
    /// Returns the user who owns invite code `code`.
    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
    /// Creates (or re-targets) a signup for `email_address`, attributed to the owner of `code`.
    async fn create_invite_from_code(&self, code: &str, email_address: &str) -> Result<Invite>;

    // --- signups ---

    /// Records a waitlist signup.
    async fn create_signup(&self, signup: Signup) -> Result<()>;
    /// Summarises the signups whose confirmation email has not been sent yet.
    async fn get_waitlist_summary(&self) -> Result<WaitlistSummary>;
    /// Returns at most `count` signups that still need an invite email.
    async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>>;
    /// Marks the given invites' confirmation emails as sent.
    async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()>;
    /// Redeems an invite by creating a user for it. Returns the new user's id, the inviting
    /// user's id (if any) and the signup's device id.
    async fn create_user_from_invite(
        &self,
        invite: &Invite,
        user: NewUserParams,
    ) -> Result<(UserId, Option<UserId>, String)>;

    /// Registers a new project for the given user.
    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;

    /// Unregisters a project for the given project id.
    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;

    /// Update file counts by extension for the given project and worktree.
    async fn update_worktree_extensions(
        &self,
        project_id: ProjectId,
        worktree_id: u64,
        extensions: HashMap<String, u32>,
    ) -> Result<()>;

    /// Get the file counts on the given project keyed by their worktree and extension.
    async fn get_project_extensions(
        &self,
        project_id: ProjectId,
    ) -> Result<HashMap<u64, HashMap<String, usize>>>;

    /// Record which users have been active in which projects during
    /// a given period of time.
    async fn record_user_activity(
        &self,
        time_period: Range<OffsetDateTime>,
        active_projects: &[(UserId, ProjectId)],
    ) -> Result<()>;

    /// Get the number of users who have been active in the given
    /// time period for at least the given time duration.
    async fn get_active_user_count(
        &self,
        time_period: Range<OffsetDateTime>,
        min_duration: Duration,
        only_collaborative: bool,
    ) -> Result<usize>;

    /// Get the users that have been most active during the given time period,
    /// along with the amount of time they have been active in each project.
    async fn get_top_users_activity_summary(
        &self,
        time_period: Range<OffsetDateTime>,
        max_user_count: usize,
    ) -> Result<Vec<UserActivitySummary>>;

    /// Get the project activity for the given user and time period.
    async fn get_user_activity_timeline(
        &self,
        time_period: Range<OffsetDateTime>,
        user_id: UserId,
    ) -> Result<Vec<UserActivityPeriod>>;

    // --- contacts ---
    // NOTE(review): the implementations of the contact methods are not visible in this chunk;
    // their exact semantics should be documented alongside them.
    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    // --- access tokens ---
    // NOTE(review): implementation not visible here; `max_access_token_count` presumably caps how
    // many hashes are kept per user — confirm against the impl.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;

    // --- orgs and channels (org/channel creation is test / seed-support only) ---
    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    // NOTE(review): the attribute below is separated from its item by a blank line but still
    // applies to `get_org_channels`, so that method only exists in test / `seed-support` builds.
    // Confirm this is intended; implementors must gate their `get_org_channels` the same way.
    #[cfg(any(test, feature = "seed-support"))]

    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;

    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;

    /// Test-only hook; `url` presumably identifies the test database to clean up — confirm.
    #[cfg(test)]
    async fn teardown(&self, url: &str);

    /// Test-only accessor returning `Option<&FakeDb>` (presumably `Some` only for the fake
    /// implementation — confirm).
    #[cfg(test)]
    fn as_fake(&self) -> Option<&FakeDb>;
}
 169
 170pub struct PostgresDb {
 171    pool: sqlx::PgPool,
 172}
 173
 174impl PostgresDb {
 175    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 176        let pool = DbOptions::new()
 177            .max_connections(max_connections)
 178            .connect(url)
 179            .await
 180            .context("failed to connect to postgres database")?;
 181        Ok(Self { pool })
 182    }
 183
 184    pub fn fuzzy_like_string(string: &str) -> String {
 185        let mut result = String::with_capacity(string.len() * 2 + 1);
 186        for c in string.chars() {
 187            if c.is_alphanumeric() {
 188                result.push('%');
 189                result.push(c);
 190            }
 191        }
 192        result.push('%');
 193        result
 194    }
 195}
 196
 197#[async_trait]
 198impl Db for PostgresDb {
 199    // users
 200
 201    async fn create_user(
 202        &self,
 203        email_address: &str,
 204        admin: bool,
 205        params: NewUserParams,
 206    ) -> Result<UserId> {
 207        let query = "
 208            INSERT INTO users (email_address, github_login, github_user_id, admin)
 209            VALUES ($1, $2, $3, $4)
 210            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 211            RETURNING id
 212        ";
 213        Ok(sqlx::query_scalar(query)
 214            .bind(email_address)
 215            .bind(params.github_login)
 216            .bind(params.github_user_id)
 217            .bind(admin)
 218            .fetch_one(&self.pool)
 219            .await
 220            .map(UserId)?)
 221    }
 222
 223    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 224        let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 225        Ok(sqlx::query_as(query)
 226            .bind(limit as i32)
 227            .bind((page * limit) as i32)
 228            .fetch_all(&self.pool)
 229            .await?)
 230    }
 231
 232    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 233        let like_string = Self::fuzzy_like_string(name_query);
 234        let query = "
 235            SELECT users.*
 236            FROM users
 237            WHERE github_login ILIKE $1
 238            ORDER BY github_login <-> $2
 239            LIMIT $3
 240        ";
 241        Ok(sqlx::query_as(query)
 242            .bind(like_string)
 243            .bind(name_query)
 244            .bind(limit as i32)
 245            .fetch_all(&self.pool)
 246            .await?)
 247    }
 248
 249    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 250        let users = self.get_users_by_ids(vec![id]).await?;
 251        Ok(users.into_iter().next())
 252    }
 253
 254    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 255        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 256        let query = "
 257            SELECT users.*
 258            FROM users
 259            WHERE users.id = ANY ($1)
 260        ";
 261        Ok(sqlx::query_as(query)
 262            .bind(&ids)
 263            .fetch_all(&self.pool)
 264            .await?)
 265    }
 266
 267    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
 268        let query = format!(
 269            "
 270            SELECT users.*
 271            FROM users
 272            WHERE invite_count = 0
 273            AND inviter_id IS{} NULL
 274            ",
 275            if invited_by_another_user { " NOT" } else { "" }
 276        );
 277
 278        Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 279    }
 280
 281    async fn get_user_by_github_account(
 282        &self,
 283        github_login: &str,
 284        github_user_id: Option<i32>,
 285    ) -> Result<Option<User>> {
 286        if let Some(github_user_id) = github_user_id {
 287            let mut user = sqlx::query_as::<_, User>(
 288                "
 289                UPDATE users
 290                SET github_login = $1
 291                WHERE github_user_id = $2
 292                RETURNING *
 293                ",
 294            )
 295            .bind(github_login)
 296            .bind(github_user_id)
 297            .fetch_optional(&self.pool)
 298            .await?;
 299
 300            if user.is_none() {
 301                user = sqlx::query_as::<_, User>(
 302                    "
 303                    UPDATE users
 304                    SET github_user_id = $1
 305                    WHERE github_login = $2
 306                    RETURNING *
 307                    ",
 308                )
 309                .bind(github_user_id)
 310                .bind(github_login)
 311                .fetch_optional(&self.pool)
 312                .await?;
 313            }
 314
 315            Ok(user)
 316        } else {
 317            Ok(sqlx::query_as(
 318                "
 319                SELECT * FROM users
 320                WHERE github_login = $1
 321                LIMIT 1
 322                ",
 323            )
 324            .bind(github_login)
 325            .fetch_optional(&self.pool)
 326            .await?)
 327        }
 328    }
 329
 330    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 331        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 332        Ok(sqlx::query(query)
 333            .bind(is_admin)
 334            .bind(id.0)
 335            .execute(&self.pool)
 336            .await
 337            .map(drop)?)
 338    }
 339
 340    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 341        let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 342        Ok(sqlx::query(query)
 343            .bind(connected_once)
 344            .bind(id.0)
 345            .execute(&self.pool)
 346            .await
 347            .map(drop)?)
 348    }
 349
 350    async fn destroy_user(&self, id: UserId) -> Result<()> {
 351        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 352        sqlx::query(query)
 353            .bind(id.0)
 354            .execute(&self.pool)
 355            .await
 356            .map(drop)?;
 357        let query = "DELETE FROM users WHERE id = $1;";
 358        Ok(sqlx::query(query)
 359            .bind(id.0)
 360            .execute(&self.pool)
 361            .await
 362            .map(drop)?)
 363    }
 364
 365    // signups
 366
 367    async fn create_signup(&self, signup: Signup) -> Result<()> {
 368        sqlx::query(
 369            "
 370            INSERT INTO signups
 371            (
 372                email_address,
 373                email_confirmation_code,
 374                email_confirmation_sent,
 375                platform_linux,
 376                platform_mac,
 377                platform_windows,
 378                platform_unknown,
 379                editor_features,
 380                programming_languages,
 381                device_id
 382            )
 383            VALUES
 384                ($1, $2, 'f', $3, $4, $5, 'f', $6, $7, $8)
 385            RETURNING id
 386            ",
 387        )
 388        .bind(&signup.email_address)
 389        .bind(&random_email_confirmation_code())
 390        .bind(&signup.platform_linux)
 391        .bind(&signup.platform_mac)
 392        .bind(&signup.platform_windows)
 393        .bind(&signup.editor_features)
 394        .bind(&signup.programming_languages)
 395        .bind(&signup.device_id)
 396        .execute(&self.pool)
 397        .await?;
 398        Ok(())
 399    }
 400
 401    async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
 402        Ok(sqlx::query_as(
 403            "
 404            SELECT
 405                COUNT(*) as count,
 406                COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
 407                COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
 408                COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count
 409            FROM (
 410                SELECT *
 411                FROM signups
 412                WHERE
 413                    NOT email_confirmation_sent
 414            ) AS unsent
 415            ",
 416        )
 417        .fetch_one(&self.pool)
 418        .await?)
 419    }
 420
 421    async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
 422        Ok(sqlx::query_as(
 423            "
 424            SELECT
 425                email_address, email_confirmation_code
 426            FROM signups
 427            WHERE
 428                NOT email_confirmation_sent AND
 429                platform_mac
 430            LIMIT $1
 431            ",
 432        )
 433        .bind(count as i32)
 434        .fetch_all(&self.pool)
 435        .await?)
 436    }
 437
 438    async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
 439        sqlx::query(
 440            "
 441            UPDATE signups
 442            SET email_confirmation_sent = 't'
 443            WHERE email_address = ANY ($1)
 444            ",
 445        )
 446        .bind(
 447            &invites
 448                .iter()
 449                .map(|s| s.email_address.as_str())
 450                .collect::<Vec<_>>(),
 451        )
 452        .execute(&self.pool)
 453        .await?;
 454        Ok(())
 455    }
 456
 457    async fn create_user_from_invite(
 458        &self,
 459        invite: &Invite,
 460        user: NewUserParams,
 461    ) -> Result<(UserId, Option<UserId>, String)> {
 462        let mut tx = self.pool.begin().await?;
 463
 464        let (signup_id, existing_user_id, inviting_user_id, device_id): (
 465            i32,
 466            Option<UserId>,
 467            Option<UserId>,
 468            String,
 469        ) = sqlx::query_as(
 470            "
 471            SELECT id, user_id, inviting_user_id, device_id
 472            FROM signups
 473            WHERE
 474                email_address = $1 AND
 475                email_confirmation_code = $2
 476            ",
 477        )
 478        .bind(&invite.email_address)
 479        .bind(&invite.email_confirmation_code)
 480        .fetch_optional(&mut tx)
 481        .await?
 482        .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
 483
 484        if existing_user_id.is_some() {
 485            Err(Error::Http(
 486                StatusCode::UNPROCESSABLE_ENTITY,
 487                "invitation already redeemed".to_string(),
 488            ))?;
 489        }
 490
 491        let user_id: UserId = sqlx::query_scalar(
 492            "
 493            INSERT INTO users
 494            (email_address, github_login, github_user_id, admin, invite_count, invite_code)
 495            VALUES
 496            ($1, $2, $3, 'f', $4, $5)
 497            RETURNING id
 498            ",
 499        )
 500        .bind(&invite.email_address)
 501        .bind(&user.github_login)
 502        .bind(&user.github_user_id)
 503        .bind(&user.invite_count)
 504        .bind(random_invite_code())
 505        .fetch_one(&mut tx)
 506        .await?;
 507
 508        sqlx::query(
 509            "
 510            UPDATE signups
 511            SET user_id = $1
 512            WHERE id = $2
 513            ",
 514        )
 515        .bind(&user_id)
 516        .bind(&signup_id)
 517        .execute(&mut tx)
 518        .await?;
 519
 520        if let Some(inviting_user_id) = inviting_user_id {
 521            let id: Option<UserId> = sqlx::query_scalar(
 522                "
 523                UPDATE users
 524                SET invite_count = invite_count - 1
 525                WHERE id = $1 AND invite_count > 0
 526                RETURNING id
 527                ",
 528            )
 529            .bind(&inviting_user_id)
 530            .fetch_optional(&mut tx)
 531            .await?;
 532
 533            if id.is_none() {
 534                Err(Error::Http(
 535                    StatusCode::UNAUTHORIZED,
 536                    "no invites remaining".to_string(),
 537                ))?;
 538            }
 539
 540            sqlx::query(
 541                "
 542                INSERT INTO contacts
 543                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 544                VALUES
 545                    ($1, $2, 't', 't', 't')
 546                ",
 547            )
 548            .bind(inviting_user_id)
 549            .bind(user_id)
 550            .execute(&mut tx)
 551            .await?;
 552        }
 553
 554        tx.commit().await?;
 555        Ok((user_id, inviting_user_id, device_id))
 556    }
 557
 558    // invite codes
 559
 560    async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
 561        let mut tx = self.pool.begin().await?;
 562        if count > 0 {
 563            sqlx::query(
 564                "
 565                UPDATE users
 566                SET invite_code = $1
 567                WHERE id = $2 AND invite_code IS NULL
 568            ",
 569            )
 570            .bind(random_invite_code())
 571            .bind(id)
 572            .execute(&mut tx)
 573            .await?;
 574        }
 575
 576        sqlx::query(
 577            "
 578            UPDATE users
 579            SET invite_count = $1
 580            WHERE id = $2
 581            ",
 582        )
 583        .bind(count as i32)
 584        .bind(id)
 585        .execute(&mut tx)
 586        .await?;
 587        tx.commit().await?;
 588        Ok(())
 589    }
 590
 591    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 592        let result: Option<(String, i32)> = sqlx::query_as(
 593            "
 594                SELECT invite_code, invite_count
 595                FROM users
 596                WHERE id = $1 AND invite_code IS NOT NULL 
 597            ",
 598        )
 599        .bind(id)
 600        .fetch_optional(&self.pool)
 601        .await?;
 602        if let Some((code, count)) = result {
 603            Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 604        } else {
 605            Ok(None)
 606        }
 607    }
 608
 609    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 610        sqlx::query_as(
 611            "
 612                SELECT *
 613                FROM users
 614                WHERE invite_code = $1
 615            ",
 616        )
 617        .bind(code)
 618        .fetch_optional(&self.pool)
 619        .await?
 620        .ok_or_else(|| {
 621            Error::Http(
 622                StatusCode::NOT_FOUND,
 623                "that invite code does not exist".to_string(),
 624            )
 625        })
 626    }
 627
 628    async fn create_invite_from_code(&self, code: &str, email_address: &str) -> Result<Invite> {
 629        let mut tx = self.pool.begin().await?;
 630
 631        let existing_user: Option<UserId> = sqlx::query_scalar(
 632            "
 633            SELECT id
 634            FROM users
 635            WHERE email_address = $1
 636            ",
 637        )
 638        .bind(email_address)
 639        .fetch_optional(&mut tx)
 640        .await?;
 641        if existing_user.is_some() {
 642            Err(anyhow!("email address is already in use"))?;
 643        }
 644
 645        let row: Option<(UserId, i32)> = sqlx::query_as(
 646            "
 647            SELECT id, invite_count
 648            FROM users
 649            WHERE invite_code = $1
 650            ",
 651        )
 652        .bind(code)
 653        .fetch_optional(&mut tx)
 654        .await?;
 655
 656        let (inviter_id, invite_count) = match row {
 657            Some(row) => row,
 658            None => Err(Error::Http(
 659                StatusCode::NOT_FOUND,
 660                "invite code not found".to_string(),
 661            ))?,
 662        };
 663
 664        if invite_count == 0 {
 665            Err(Error::Http(
 666                StatusCode::UNAUTHORIZED,
 667                "no invites remaining".to_string(),
 668            ))?;
 669        }
 670
 671        let email_confirmation_code: String = sqlx::query_scalar(
 672            "
 673            INSERT INTO signups
 674            (
 675                email_address,
 676                email_confirmation_code,
 677                email_confirmation_sent,
 678                inviting_user_id,
 679                platform_linux,
 680                platform_mac,
 681                platform_windows,
 682                platform_unknown
 683            )
 684            VALUES
 685                ($1, $2, 'f', $3, 'f', 'f', 'f', 't')
 686            ON CONFLICT (email_address)
 687            DO UPDATE SET
 688                inviting_user_id = excluded.inviting_user_id
 689            RETURNING email_confirmation_code
 690            ",
 691        )
 692        .bind(&email_address)
 693        .bind(&random_email_confirmation_code())
 694        .bind(&inviter_id)
 695        .fetch_one(&mut tx)
 696        .await?;
 697
 698        tx.commit().await?;
 699
 700        Ok(Invite {
 701            email_address: email_address.into(),
 702            email_confirmation_code,
 703        })
 704    }
 705
 706    // projects
 707
 708    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
 709        Ok(sqlx::query_scalar(
 710            "
 711            INSERT INTO projects(host_user_id)
 712            VALUES ($1)
 713            RETURNING id
 714            ",
 715        )
 716        .bind(host_user_id)
 717        .fetch_one(&self.pool)
 718        .await
 719        .map(ProjectId)?)
 720    }
 721
 722    async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
 723        sqlx::query(
 724            "
 725            UPDATE projects
 726            SET unregistered = 't'
 727            WHERE id = $1
 728            ",
 729        )
 730        .bind(project_id)
 731        .execute(&self.pool)
 732        .await?;
 733        Ok(())
 734    }
 735
 736    async fn update_worktree_extensions(
 737        &self,
 738        project_id: ProjectId,
 739        worktree_id: u64,
 740        extensions: HashMap<String, u32>,
 741    ) -> Result<()> {
 742        if extensions.is_empty() {
 743            return Ok(());
 744        }
 745
 746        let mut query = QueryBuilder::new(
 747            "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
 748        );
 749        query.push_values(extensions, |mut query, (extension, count)| {
 750            query
 751                .push_bind(project_id)
 752                .push_bind(worktree_id as i32)
 753                .push_bind(extension)
 754                .push_bind(count as i32);
 755        });
 756        query.push(
 757            "
 758            ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
 759            count = excluded.count
 760            ",
 761        );
 762        query.build().execute(&self.pool).await?;
 763
 764        Ok(())
 765    }
 766
 767    async fn get_project_extensions(
 768        &self,
 769        project_id: ProjectId,
 770    ) -> Result<HashMap<u64, HashMap<String, usize>>> {
 771        #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
 772        struct WorktreeExtension {
 773            worktree_id: i32,
 774            extension: String,
 775            count: i32,
 776        }
 777
 778        let query = "
 779            SELECT worktree_id, extension, count
 780            FROM worktree_extensions
 781            WHERE project_id = $1
 782        ";
 783        let counts = sqlx::query_as::<_, WorktreeExtension>(query)
 784            .bind(&project_id)
 785            .fetch_all(&self.pool)
 786            .await?;
 787
 788        let mut extension_counts = HashMap::default();
 789        for count in counts {
 790            extension_counts
 791                .entry(count.worktree_id as u64)
 792                .or_insert_with(HashMap::default)
 793                .insert(count.extension, count.count as usize);
 794        }
 795        Ok(extension_counts)
 796    }
 797
 798    async fn record_user_activity(
 799        &self,
 800        time_period: Range<OffsetDateTime>,
 801        projects: &[(UserId, ProjectId)],
 802    ) -> Result<()> {
 803        let query = "
 804            INSERT INTO project_activity_periods
 805            (ended_at, duration_millis, user_id, project_id)
 806            VALUES
 807            ($1, $2, $3, $4);
 808        ";
 809
 810        let mut tx = self.pool.begin().await?;
 811        let duration_millis =
 812            ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
 813        for (user_id, project_id) in projects {
 814            sqlx::query(query)
 815                .bind(time_period.end)
 816                .bind(duration_millis)
 817                .bind(user_id)
 818                .bind(project_id)
 819                .execute(&mut tx)
 820                .await?;
 821        }
 822        tx.commit().await?;
 823
 824        Ok(())
 825    }
 826
 827    async fn get_active_user_count(
 828        &self,
 829        time_period: Range<OffsetDateTime>,
 830        min_duration: Duration,
 831        only_collaborative: bool,
 832    ) -> Result<usize> {
 833        let mut with_clause = String::new();
 834        with_clause.push_str("WITH\n");
 835        with_clause.push_str(
 836            "
 837            project_durations AS (
 838                SELECT user_id, project_id, SUM(duration_millis) AS project_duration
 839                FROM project_activity_periods
 840                WHERE $1 < ended_at AND ended_at <= $2
 841                GROUP BY user_id, project_id
 842            ),
 843            ",
 844        );
 845        with_clause.push_str(
 846            "
 847            project_collaborators as (
 848                SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
 849                FROM project_durations
 850                GROUP BY project_id
 851            ),
 852            ",
 853        );
 854
 855        if only_collaborative {
 856            with_clause.push_str(
 857                "
 858                user_durations AS (
 859                    SELECT user_id, SUM(project_duration) as total_duration
 860                    FROM project_durations, project_collaborators
 861                    WHERE
 862                        project_durations.project_id = project_collaborators.project_id AND
 863                        max_collaborators > 1
 864                    GROUP BY user_id
 865                    ORDER BY total_duration DESC
 866                    LIMIT $3
 867                )
 868                ",
 869            );
 870        } else {
 871            with_clause.push_str(
 872                "
 873                user_durations AS (
 874                    SELECT user_id, SUM(project_duration) as total_duration
 875                    FROM project_durations
 876                    GROUP BY user_id
 877                    ORDER BY total_duration DESC
 878                    LIMIT $3
 879                )
 880                ",
 881            );
 882        }
 883
 884        let query = format!(
 885            "
 886            {with_clause}
 887            SELECT count(user_durations.user_id)
 888            FROM user_durations
 889            WHERE user_durations.total_duration >= $3
 890            "
 891        );
 892
 893        let count: i64 = sqlx::query_scalar(&query)
 894            .bind(time_period.start)
 895            .bind(time_period.end)
 896            .bind(min_duration.as_millis() as i64)
 897            .fetch_one(&self.pool)
 898            .await?;
 899        Ok(count as usize)
 900    }
 901
 902    async fn get_top_users_activity_summary(
 903        &self,
 904        time_period: Range<OffsetDateTime>,
 905        max_user_count: usize,
 906    ) -> Result<Vec<UserActivitySummary>> {
 907        let query = "
 908            WITH
 909                project_durations AS (
 910                    SELECT user_id, project_id, SUM(duration_millis) AS project_duration
 911                    FROM project_activity_periods
 912                    WHERE $1 < ended_at AND ended_at <= $2
 913                    GROUP BY user_id, project_id
 914                ),
 915                user_durations AS (
 916                    SELECT user_id, SUM(project_duration) as total_duration
 917                    FROM project_durations
 918                    GROUP BY user_id
 919                    ORDER BY total_duration DESC
 920                    LIMIT $3
 921                ),
 922                project_collaborators as (
 923                    SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
 924                    FROM project_durations
 925                    GROUP BY project_id
 926                )
 927            SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
 928            FROM user_durations, project_durations, project_collaborators, users
 929            WHERE
 930                user_durations.user_id = project_durations.user_id AND
 931                user_durations.user_id = users.id AND
 932                project_durations.project_id = project_collaborators.project_id
 933            ORDER BY total_duration DESC, user_id ASC, project_id ASC
 934        ";
 935
 936        let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
 937            .bind(time_period.start)
 938            .bind(time_period.end)
 939            .bind(max_user_count as i32)
 940            .fetch(&self.pool);
 941
 942        let mut result = Vec::<UserActivitySummary>::new();
 943        while let Some(row) = rows.next().await {
 944            let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?;
 945            let project_id = project_id;
 946            let duration = Duration::from_millis(duration_millis as u64);
 947            let project_activity = ProjectActivitySummary {
 948                id: project_id,
 949                duration,
 950                max_collaborators: project_collaborators as usize,
 951            };
 952            if let Some(last_summary) = result.last_mut() {
 953                if last_summary.id == user_id {
 954                    last_summary.project_activity.push(project_activity);
 955                    continue;
 956                }
 957            }
 958            result.push(UserActivitySummary {
 959                id: user_id,
 960                project_activity: vec![project_activity],
 961                github_login,
 962            });
 963        }
 964
 965        Ok(result)
 966    }
 967
 968    async fn get_user_activity_timeline(
 969        &self,
 970        time_period: Range<OffsetDateTime>,
 971        user_id: UserId,
 972    ) -> Result<Vec<UserActivityPeriod>> {
 973        const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
 974
 975        let query = "
 976            SELECT
 977                project_activity_periods.ended_at,
 978                project_activity_periods.duration_millis,
 979                project_activity_periods.project_id,
 980                worktree_extensions.extension,
 981                worktree_extensions.count
 982            FROM project_activity_periods
 983            LEFT OUTER JOIN
 984                worktree_extensions
 985            ON
 986                project_activity_periods.project_id = worktree_extensions.project_id
 987            WHERE
 988                project_activity_periods.user_id = $1 AND
 989                $2 < project_activity_periods.ended_at AND
 990                project_activity_periods.ended_at <= $3
 991            ORDER BY project_activity_periods.id ASC
 992        ";
 993
 994        let mut rows = sqlx::query_as::<
 995            _,
 996            (
 997                PrimitiveDateTime,
 998                i32,
 999                ProjectId,
1000                Option<String>,
1001                Option<i32>,
1002            ),
1003        >(query)
1004        .bind(user_id)
1005        .bind(time_period.start)
1006        .bind(time_period.end)
1007        .fetch(&self.pool);
1008
1009        let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
1010        while let Some(row) = rows.next().await {
1011            let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
1012            let ended_at = ended_at.assume_utc();
1013            let duration = Duration::from_millis(duration_millis as u64);
1014            let started_at = ended_at - duration;
1015            let project_time_periods = time_periods.entry(project_id).or_default();
1016
1017            if let Some(prev_duration) = project_time_periods.last_mut() {
1018                if started_at <= prev_duration.end + COALESCE_THRESHOLD
1019                    && ended_at >= prev_duration.start
1020                {
1021                    prev_duration.end = cmp::max(prev_duration.end, ended_at);
1022                } else {
1023                    project_time_periods.push(UserActivityPeriod {
1024                        project_id,
1025                        start: started_at,
1026                        end: ended_at,
1027                        extensions: Default::default(),
1028                    });
1029                }
1030            } else {
1031                project_time_periods.push(UserActivityPeriod {
1032                    project_id,
1033                    start: started_at,
1034                    end: ended_at,
1035                    extensions: Default::default(),
1036                });
1037            }
1038
1039            if let Some((extension, extension_count)) = extension.zip(extension_count) {
1040                project_time_periods
1041                    .last_mut()
1042                    .unwrap()
1043                    .extensions
1044                    .insert(extension, extension_count as usize);
1045            }
1046        }
1047
1048        let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
1049        durations.sort_unstable_by_key(|duration| duration.start);
1050        Ok(durations)
1051    }
1052
1053    // contacts
1054
1055    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
1056        let query = "
1057            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
1058            FROM contacts
1059            WHERE user_id_a = $1 OR user_id_b = $1;
1060        ";
1061
1062        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
1063            .bind(user_id)
1064            .fetch(&self.pool);
1065
1066        let mut contacts = vec![Contact::Accepted {
1067            user_id,
1068            should_notify: false,
1069        }];
1070        while let Some(row) = rows.next().await {
1071            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
1072
1073            if user_id_a == user_id {
1074                if accepted {
1075                    contacts.push(Contact::Accepted {
1076                        user_id: user_id_b,
1077                        should_notify: should_notify && a_to_b,
1078                    });
1079                } else if a_to_b {
1080                    contacts.push(Contact::Outgoing { user_id: user_id_b })
1081                } else {
1082                    contacts.push(Contact::Incoming {
1083                        user_id: user_id_b,
1084                        should_notify,
1085                    });
1086                }
1087            } else if accepted {
1088                contacts.push(Contact::Accepted {
1089                    user_id: user_id_a,
1090                    should_notify: should_notify && !a_to_b,
1091                });
1092            } else if a_to_b {
1093                contacts.push(Contact::Incoming {
1094                    user_id: user_id_a,
1095                    should_notify,
1096                });
1097            } else {
1098                contacts.push(Contact::Outgoing { user_id: user_id_a });
1099            }
1100        }
1101
1102        contacts.sort_unstable_by_key(|contact| contact.user_id());
1103
1104        Ok(contacts)
1105    }
1106
1107    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1108        let (id_a, id_b) = if user_id_1 < user_id_2 {
1109            (user_id_1, user_id_2)
1110        } else {
1111            (user_id_2, user_id_1)
1112        };
1113
1114        let query = "
1115            SELECT 1 FROM contacts
1116            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
1117            LIMIT 1
1118        ";
1119        Ok(sqlx::query_scalar::<_, i32>(query)
1120            .bind(id_a.0)
1121            .bind(id_b.0)
1122            .fetch_optional(&self.pool)
1123            .await?
1124            .is_some())
1125    }
1126
1127    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
1128        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
1129            (sender_id, receiver_id, true)
1130        } else {
1131            (receiver_id, sender_id, false)
1132        };
1133        let query = "
1134            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1135            VALUES ($1, $2, $3, 'f', 't')
1136            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1137            SET
1138                accepted = 't',
1139                should_notify = 'f'
1140            WHERE
1141                NOT contacts.accepted AND
1142                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1143                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1144        ";
1145        let result = sqlx::query(query)
1146            .bind(id_a.0)
1147            .bind(id_b.0)
1148            .bind(a_to_b)
1149            .execute(&self.pool)
1150            .await?;
1151
1152        if result.rows_affected() == 1 {
1153            Ok(())
1154        } else {
1155            Err(anyhow!("contact already requested"))?
1156        }
1157    }
1158
1159    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1160        let (id_a, id_b) = if responder_id < requester_id {
1161            (responder_id, requester_id)
1162        } else {
1163            (requester_id, responder_id)
1164        };
1165        let query = "
1166            DELETE FROM contacts
1167            WHERE user_id_a = $1 AND user_id_b = $2;
1168        ";
1169        let result = sqlx::query(query)
1170            .bind(id_a.0)
1171            .bind(id_b.0)
1172            .execute(&self.pool)
1173            .await?;
1174
1175        if result.rows_affected() == 1 {
1176            Ok(())
1177        } else {
1178            Err(anyhow!("no such contact"))?
1179        }
1180    }
1181
1182    async fn dismiss_contact_notification(
1183        &self,
1184        user_id: UserId,
1185        contact_user_id: UserId,
1186    ) -> Result<()> {
1187        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1188            (user_id, contact_user_id, true)
1189        } else {
1190            (contact_user_id, user_id, false)
1191        };
1192
1193        let query = "
1194            UPDATE contacts
1195            SET should_notify = 'f'
1196            WHERE
1197                user_id_a = $1 AND user_id_b = $2 AND
1198                (
1199                    (a_to_b = $3 AND accepted) OR
1200                    (a_to_b != $3 AND NOT accepted)
1201                );
1202        ";
1203
1204        let result = sqlx::query(query)
1205            .bind(id_a.0)
1206            .bind(id_b.0)
1207            .bind(a_to_b)
1208            .execute(&self.pool)
1209            .await?;
1210
1211        if result.rows_affected() == 0 {
1212            Err(anyhow!("no such contact request"))?;
1213        }
1214
1215        Ok(())
1216    }
1217
1218    async fn respond_to_contact_request(
1219        &self,
1220        responder_id: UserId,
1221        requester_id: UserId,
1222        accept: bool,
1223    ) -> Result<()> {
1224        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1225            (responder_id, requester_id, false)
1226        } else {
1227            (requester_id, responder_id, true)
1228        };
1229        let result = if accept {
1230            let query = "
1231                UPDATE contacts
1232                SET accepted = 't', should_notify = 't'
1233                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1234            ";
1235            sqlx::query(query)
1236                .bind(id_a.0)
1237                .bind(id_b.0)
1238                .bind(a_to_b)
1239                .execute(&self.pool)
1240                .await?
1241        } else {
1242            let query = "
1243                DELETE FROM contacts
1244                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1245            ";
1246            sqlx::query(query)
1247                .bind(id_a.0)
1248                .bind(id_b.0)
1249                .bind(a_to_b)
1250                .execute(&self.pool)
1251                .await?
1252        };
1253        if result.rows_affected() == 1 {
1254            Ok(())
1255        } else {
1256            Err(anyhow!("no such contact request"))?
1257        }
1258    }
1259
1260    // access tokens
1261
1262    async fn create_access_token_hash(
1263        &self,
1264        user_id: UserId,
1265        access_token_hash: &str,
1266        max_access_token_count: usize,
1267    ) -> Result<()> {
1268        let insert_query = "
1269            INSERT INTO access_tokens (user_id, hash)
1270            VALUES ($1, $2);
1271        ";
1272        let cleanup_query = "
1273            DELETE FROM access_tokens
1274            WHERE id IN (
1275                SELECT id from access_tokens
1276                WHERE user_id = $1
1277                ORDER BY id DESC
1278                OFFSET $3
1279            )
1280        ";
1281
1282        let mut tx = self.pool.begin().await?;
1283        sqlx::query(insert_query)
1284            .bind(user_id.0)
1285            .bind(access_token_hash)
1286            .execute(&mut tx)
1287            .await?;
1288        sqlx::query(cleanup_query)
1289            .bind(user_id.0)
1290            .bind(access_token_hash)
1291            .bind(max_access_token_count as i32)
1292            .execute(&mut tx)
1293            .await?;
1294        Ok(tx.commit().await?)
1295    }
1296
1297    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1298        let query = "
1299            SELECT hash
1300            FROM access_tokens
1301            WHERE user_id = $1
1302            ORDER BY id DESC
1303        ";
1304        Ok(sqlx::query_scalar(query)
1305            .bind(user_id.0)
1306            .fetch_all(&self.pool)
1307            .await?)
1308    }
1309
1310    // orgs
1311
1312    #[allow(unused)] // Help rust-analyzer
1313    #[cfg(any(test, feature = "seed-support"))]
1314    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1315        let query = "
1316            SELECT *
1317            FROM orgs
1318            WHERE slug = $1
1319        ";
1320        Ok(sqlx::query_as(query)
1321            .bind(slug)
1322            .fetch_optional(&self.pool)
1323            .await?)
1324    }
1325
1326    #[cfg(any(test, feature = "seed-support"))]
1327    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1328        let query = "
1329            INSERT INTO orgs (name, slug)
1330            VALUES ($1, $2)
1331            RETURNING id
1332        ";
1333        Ok(sqlx::query_scalar(query)
1334            .bind(name)
1335            .bind(slug)
1336            .fetch_one(&self.pool)
1337            .await
1338            .map(OrgId)?)
1339    }
1340
1341    #[cfg(any(test, feature = "seed-support"))]
1342    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1343        let query = "
1344            INSERT INTO org_memberships (org_id, user_id, admin)
1345            VALUES ($1, $2, $3)
1346            ON CONFLICT DO NOTHING
1347        ";
1348        Ok(sqlx::query(query)
1349            .bind(org_id.0)
1350            .bind(user_id.0)
1351            .bind(is_admin)
1352            .execute(&self.pool)
1353            .await
1354            .map(drop)?)
1355    }
1356
1357    // channels
1358
1359    #[cfg(any(test, feature = "seed-support"))]
1360    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1361        let query = "
1362            INSERT INTO channels (owner_id, owner_is_user, name)
1363            VALUES ($1, false, $2)
1364            RETURNING id
1365        ";
1366        Ok(sqlx::query_scalar(query)
1367            .bind(org_id.0)
1368            .bind(name)
1369            .fetch_one(&self.pool)
1370            .await
1371            .map(ChannelId)?)
1372    }
1373
1374    #[allow(unused)] // Help rust-analyzer
1375    #[cfg(any(test, feature = "seed-support"))]
1376    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1377        let query = "
1378            SELECT *
1379            FROM channels
1380            WHERE
1381                channels.owner_is_user = false AND
1382                channels.owner_id = $1
1383        ";
1384        Ok(sqlx::query_as(query)
1385            .bind(org_id.0)
1386            .fetch_all(&self.pool)
1387            .await?)
1388    }
1389
1390    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1391        let query = "
1392            SELECT
1393                channels.*
1394            FROM
1395                channel_memberships, channels
1396            WHERE
1397                channel_memberships.user_id = $1 AND
1398                channel_memberships.channel_id = channels.id
1399        ";
1400        Ok(sqlx::query_as(query)
1401            .bind(user_id.0)
1402            .fetch_all(&self.pool)
1403            .await?)
1404    }
1405
1406    async fn can_user_access_channel(
1407        &self,
1408        user_id: UserId,
1409        channel_id: ChannelId,
1410    ) -> Result<bool> {
1411        let query = "
1412            SELECT id
1413            FROM channel_memberships
1414            WHERE user_id = $1 AND channel_id = $2
1415            LIMIT 1
1416        ";
1417        Ok(sqlx::query_scalar::<_, i32>(query)
1418            .bind(user_id.0)
1419            .bind(channel_id.0)
1420            .fetch_optional(&self.pool)
1421            .await
1422            .map(|e| e.is_some())?)
1423    }
1424
1425    #[cfg(any(test, feature = "seed-support"))]
1426    async fn add_channel_member(
1427        &self,
1428        channel_id: ChannelId,
1429        user_id: UserId,
1430        is_admin: bool,
1431    ) -> Result<()> {
1432        let query = "
1433            INSERT INTO channel_memberships (channel_id, user_id, admin)
1434            VALUES ($1, $2, $3)
1435            ON CONFLICT DO NOTHING
1436        ";
1437        Ok(sqlx::query(query)
1438            .bind(channel_id.0)
1439            .bind(user_id.0)
1440            .bind(is_admin)
1441            .execute(&self.pool)
1442            .await
1443            .map(drop)?)
1444    }
1445
1446    // messages
1447
1448    async fn create_channel_message(
1449        &self,
1450        channel_id: ChannelId,
1451        sender_id: UserId,
1452        body: &str,
1453        timestamp: OffsetDateTime,
1454        nonce: u128,
1455    ) -> Result<MessageId> {
1456        let query = "
1457            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1458            VALUES ($1, $2, $3, $4, $5)
1459            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1460            RETURNING id
1461        ";
1462        Ok(sqlx::query_scalar(query)
1463            .bind(channel_id.0)
1464            .bind(sender_id.0)
1465            .bind(body)
1466            .bind(timestamp)
1467            .bind(Uuid::from_u128(nonce))
1468            .fetch_one(&self.pool)
1469            .await
1470            .map(MessageId)?)
1471    }
1472
1473    async fn get_channel_messages(
1474        &self,
1475        channel_id: ChannelId,
1476        count: usize,
1477        before_id: Option<MessageId>,
1478    ) -> Result<Vec<ChannelMessage>> {
1479        let query = r#"
1480            SELECT * FROM (
1481                SELECT
1482                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1483                FROM
1484                    channel_messages
1485                WHERE
1486                    channel_id = $1 AND
1487                    id < $2
1488                ORDER BY id DESC
1489                LIMIT $3
1490            ) as recent_messages
1491            ORDER BY id ASC
1492        "#;
1493        Ok(sqlx::query_as(query)
1494            .bind(channel_id.0)
1495            .bind(before_id.unwrap_or(MessageId::MAX))
1496            .bind(count as i64)
1497            .fetch_all(&self.pool)
1498            .await?)
1499    }
1500
1501    #[cfg(test)]
1502    async fn teardown(&self, url: &str) {
1503        use util::ResultExt;
1504
1505        let query = "
1506            SELECT pg_terminate_backend(pg_stat_activity.pid)
1507            FROM pg_stat_activity
1508            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1509        ";
1510        sqlx::query(query).execute(&self.pool).await.log_err();
1511        self.pool.close().await;
1512        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1513            .await
1514            .log_err();
1515    }
1516
    // This implementation is not the in-memory `FakeDb`, so there is never a fake to hand
    // back to tests.
    #[cfg(test)]
    fn as_fake(&self) -> Option<&FakeDb> {
        None
    }
1521}
1522
// Declares a newtype id that wraps an `i32`.
//
// The generated type is `#[sqlx(transparent)]` and `#[serde(transparent)]`, so it is stored
// and serialized as the bare integer. It also gets a `MAX` constant, `u64` conversions
// (`from_proto` / `to_proto`, presumably for the wire protocol), and a `Display` impl
// that prints the inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // Largest representable id; usable as an "unbounded" upper limit in queries.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // NOTE(review): `as` truncates/wraps silently, so values above `i32::MAX` do not
            // round-trip.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // NOTE(review): a negative inner value becomes a huge `u64` here.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
1565
id_type!(UserId);
/// A user account, decoded from database rows via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    /// GitHub's numeric account id. May be `None`; `get_user_by_github_account` records it
    /// when a user is matched by login.
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    /// Set via `Db::set_user_is_admin`.
    pub admin: bool,
    /// See `Db::get_invite_code_for_user`.
    pub invite_code: Option<String>,
    /// See `Db::set_invite_count_for_user`.
    pub invite_count: i32,
    /// Set via `Db::set_user_connected_once`.
    pub connected_once: bool,
}
1578
id_type!(ProjectId);
/// A project record, decoded from database rows via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// Id of the user hosting the project (per the field name; confirm against the schema).
    pub host_user_id: UserId,
    pub unregistered: bool,
}
1586
/// One entry returned by `get_top_users_activity_summary`: a user and their activity per
/// project.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// Ordered by ascending project id.
    pub project_activity: Vec<ProjectActivitySummary>,
}
1593
/// A user's activity on one project within the queried time window.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ProjectActivitySummary {
    pub id: ProjectId,
    /// Sum of the user's activity-period durations on this project.
    pub duration: Duration,
    /// Number of distinct users with activity on this project in the same window.
    pub max_collaborators: usize,
}
1600
/// A contiguous span of activity on one project, as produced by `get_user_activity_timeline`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    pub project_id: ProjectId,
    /// Serialized as ISO 8601.
    #[serde(with = "time::serde::iso8601")]
    pub start: OffsetDateTime,
    /// Serialized as ISO 8601.
    #[serde(with = "time::serde::iso8601")]
    pub end: OffsetDateTime,
    /// File-extension counts recorded for the project in `worktree_extensions`.
    pub extensions: HashMap<String, usize>,
}
1610
id_type!(OrgId);
/// A row of the `orgs` table.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Lookup key used by `find_org_by_slug`.
    pub slug: String,
}
1618
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Id of the owner. It is an org id when `owner_is_user` is false.
    pub owner_id: i32,
    /// Whether `owner_id` refers to a user rather than an org.
    pub owner_is_user: bool,
}
1627
id_type!(MessageId);
/// A message posted to a channel.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// `get_channel_messages` reads this column `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Caller-supplied nonce. Inserting a duplicate returns the existing message's id.
    pub nonce: Uuid,
}
1638
/// A contact relationship as seen from one user's side (see `get_contacts`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// Both users are contacts. `should_notify` is set when this user sent the original
    /// request and the acceptance notification has not been dismissed.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// This user sent a request that is still pending.
    Outgoing {
        user_id: UserId,
    },
    /// `user_id` sent this user a request that is still pending.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1653
1654impl Contact {
1655    pub fn user_id(&self) -> UserId {
1656        match self {
1657            Contact::Accepted { user_id, .. } => *user_id,
1658            Contact::Outgoing { user_id } => *user_id,
1659            Contact::Incoming { user_id, .. } => *user_id,
1660        }
1661    }
1662}
1663
/// A pending contact request received from `requester_id`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    /// Presumably mirrors the contact row's `should_notify` flag (TODO confirm with callers).
    pub should_notify: bool,
}
1669
/// Waitlist signup data passed to `Db::create_signup`.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    // Platform flags; presumably the platforms the person said they use.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: String,
}
1680
/// Aggregate waitlist counts returned by `Db::get_waitlist_summary`.
///
/// Each field is `#[sqlx(default)]`, so a column missing from the row decodes as `0`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
}
1692
/// An invitation, as returned by `Db::create_invite_from_code`.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
1698
/// GitHub identity and initial invite allowance passed to `Db::create_user`.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}
1705
/// Generates a random 16-character invite code with `nanoid!`.
fn random_invite_code() -> String {
    nanoid::nanoid!(16)
}
1709
/// Generates a random 64-character email confirmation code with `nanoid!`.
fn random_email_confirmation_code() -> String {
    nanoid::nanoid!(64)
}
1713
1714#[cfg(test)]
1715pub use test::*;
1716
1717#[cfg(test)]
1718mod test {
1719    use super::*;
1720    use anyhow::anyhow;
1721    use collections::BTreeMap;
1722    use gpui::executor::Background;
1723    use lazy_static::lazy_static;
1724    use parking_lot::Mutex;
1725    use rand::prelude::*;
1726    use sqlx::{
1727        migrate::{MigrateDatabase, Migrator},
1728        Postgres,
1729    };
1730    use std::{path::Path, sync::Arc};
1731    use util::post_inc;
1732
    /// In-memory [`Db`] implementation used by tests.
    ///
    /// Each table lives in its own `Mutex`-guarded collection. `background` is used to inject
    /// simulated random delays into the async trait methods.
    pub struct FakeDb {
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        // Keyed by (project, presumably worktree id, extension) -> count; TODO confirm.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        // Value is presumably the membership's admin flag.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        // Value is presumably the membership's admin flag.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        pub contacts: Mutex<Vec<FakeContact>>,
        // Next id to hand out for each kind of row (e.g. `create_user` advances
        // `next_user_id` with `post_inc`).
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
1750
    /// A contact relationship stored by [`FakeDb`], recorded from the requester's side.
    #[derive(Debug)]
    pub struct FakeContact {
        pub requester_id: UserId,
        pub responder_id: UserId,
        pub accepted: bool,
        pub should_notify: bool,
    }
1758
1759    impl FakeDb {
1760        pub fn new(background: Arc<Background>) -> Self {
1761            Self {
1762                background,
1763                users: Default::default(),
1764                next_user_id: Mutex::new(0),
1765                projects: Default::default(),
1766                worktree_extensions: Default::default(),
1767                next_project_id: Mutex::new(1),
1768                orgs: Default::default(),
1769                next_org_id: Mutex::new(1),
1770                org_memberships: Default::default(),
1771                channels: Default::default(),
1772                next_channel_id: Mutex::new(1),
1773                channel_memberships: Default::default(),
1774                channel_messages: Default::default(),
1775                next_channel_message_id: Mutex::new(1),
1776                contacts: Default::default(),
1777            }
1778        }
1779    }
1780
1781    #[async_trait]
1782    impl Db for FakeDb {
1783        async fn create_user(
1784            &self,
1785            email_address: &str,
1786            admin: bool,
1787            params: NewUserParams,
1788        ) -> Result<UserId> {
1789            self.background.simulate_random_delay().await;
1790
1791            let mut users = self.users.lock();
1792            if let Some(user) = users
1793                .values()
1794                .find(|user| user.github_login == params.github_login)
1795            {
1796                Ok(user.id)
1797            } else {
1798                let id = post_inc(&mut *self.next_user_id.lock());
1799                let user_id = UserId(id);
1800                users.insert(
1801                    user_id,
1802                    User {
1803                        id: user_id,
1804                        github_login: params.github_login,
1805                        github_user_id: Some(params.github_user_id),
1806                        email_address: Some(email_address.to_string()),
1807                        admin,
1808                        invite_code: None,
1809                        invite_count: 0,
1810                        connected_once: false,
1811                    },
1812                );
1813                Ok(user_id)
1814            }
1815        }
1816
        /// Not supported by the fake database; panics if called.
        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
1820
        /// Not supported by the fake database; panics if called.
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
1824
1825        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1826            self.background.simulate_random_delay().await;
1827            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1828        }
1829
1830        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1831            self.background.simulate_random_delay().await;
1832            let users = self.users.lock();
1833            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1834        }
1835
        /// Not supported by the fake database; panics if called.
        async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
            unimplemented!()
        }
1839
1840        async fn get_user_by_github_account(
1841            &self,
1842            github_login: &str,
1843            github_user_id: Option<i32>,
1844        ) -> Result<Option<User>> {
1845            self.background.simulate_random_delay().await;
1846            if let Some(github_user_id) = github_user_id {
1847                for user in self.users.lock().values_mut() {
1848                    if user.github_user_id == Some(github_user_id) {
1849                        user.github_login = github_login.into();
1850                        return Ok(Some(user.clone()));
1851                    }
1852                    if user.github_login == github_login {
1853                        user.github_user_id = Some(github_user_id);
1854                        return Ok(Some(user.clone()));
1855                    }
1856                }
1857                Ok(None)
1858            } else {
1859                Ok(self
1860                    .users
1861                    .lock()
1862                    .values()
1863                    .find(|user| user.github_login == github_login)
1864                    .cloned())
1865            }
1866        }
1867
        /// Not supported by the fake database; panics via `unimplemented!()` if called.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
1871
1872        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
1873            self.background.simulate_random_delay().await;
1874            let mut users = self.users.lock();
1875            let mut user = users
1876                .get_mut(&id)
1877                .ok_or_else(|| anyhow!("user not found"))?;
1878            user.connected_once = connected_once;
1879            Ok(())
1880        }
1881
        /// Not supported by the fake database; panics via `unimplemented!()` if called.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }
1885
1886        // signups
1887
        /// Not supported by the fake database; panics via `unimplemented!()` if called.
        async fn create_signup(&self, _signup: Signup) -> Result<()> {
            unimplemented!()
        }
1891
        /// Not supported by the fake database; panics via `unimplemented!()` if called.
        async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
            unimplemented!()
        }
1895
        /// Not supported by the fake database; panics via `unimplemented!()` if called.
        async fn get_unsent_invites(&self, _count: usize) -> Result<Vec<Invite>> {
            unimplemented!()
        }
1899
        /// Not supported by the fake database; panics via `unimplemented!()` if called.
        async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
            unimplemented!()
        }
1903
        /// Not supported by the fake database; panics via `unimplemented!()` if called.
        async fn create_user_from_invite(
            &self,
            _invite: &Invite,
            _user: NewUserParams,
        ) -> Result<(UserId, Option<UserId>, String)> {
            unimplemented!()
        }
1911
1912        // invite codes
1913
        /// Not supported by the fake database; panics via `unimplemented!()` if called.
        async fn set_invite_count_for_user(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }
1917
        /// Always reports that the user has no invite code (`Ok(None)`), after a
        /// simulated delay; the fake never stores invite codes.
        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            self.background.simulate_random_delay().await;
            Ok(None)
        }
1922
        /// Not supported by the fake database; panics via `unimplemented!()` if called.
        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }
1926
        /// Not supported by the fake database; panics via `unimplemented!()` if called.
        async fn create_invite_from_code(
            &self,
            _code: &str,
            _email_address: &str,
        ) -> Result<Invite> {
            unimplemented!()
        }
1934
1935        // projects
1936
1937        async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
1938            self.background.simulate_random_delay().await;
1939            if !self.users.lock().contains_key(&host_user_id) {
1940                Err(anyhow!("no such user"))?;
1941            }
1942
1943            let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
1944            self.projects.lock().insert(
1945                project_id,
1946                Project {
1947                    id: project_id,
1948                    host_user_id,
1949                    unregistered: false,
1950                },
1951            );
1952            Ok(project_id)
1953        }
1954
1955        async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
1956            self.background.simulate_random_delay().await;
1957            self.projects
1958                .lock()
1959                .get_mut(&project_id)
1960                .ok_or_else(|| anyhow!("no such project"))?
1961                .unregistered = true;
1962            Ok(())
1963        }
1964
1965        async fn update_worktree_extensions(
1966            &self,
1967            project_id: ProjectId,
1968            worktree_id: u64,
1969            extensions: HashMap<String, u32>,
1970        ) -> Result<()> {
1971            self.background.simulate_random_delay().await;
1972            if !self.projects.lock().contains_key(&project_id) {
1973                Err(anyhow!("no such project"))?;
1974            }
1975
1976            for (extension, count) in extensions {
1977                self.worktree_extensions
1978                    .lock()
1979                    .insert((project_id, worktree_id, extension), count);
1980            }
1981
1982            Ok(())
1983        }
1984
        /// Not supported by the fake database; panics via `unimplemented!()` if called.
        async fn get_project_extensions(
            &self,
            _project_id: ProjectId,
        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
            unimplemented!()
        }
1991
        /// Not supported by the fake database; panics via `unimplemented!()` if called.
        async fn record_user_activity(
            &self,
            _time_period: Range<OffsetDateTime>,
            _active_projects: &[(UserId, ProjectId)],
        ) -> Result<()> {
            unimplemented!()
        }
1999
        /// Not supported by the fake database; panics via `unimplemented!()` if called.
        async fn get_active_user_count(
            &self,
            _time_period: Range<OffsetDateTime>,
            _min_duration: Duration,
            _only_collaborative: bool,
        ) -> Result<usize> {
            unimplemented!()
        }
2008
        /// Not supported by the fake database; panics via `unimplemented!()` if called.
        async fn get_top_users_activity_summary(
            &self,
            _time_period: Range<OffsetDateTime>,
            _limit: usize,
        ) -> Result<Vec<UserActivitySummary>> {
            unimplemented!()
        }
2016
        /// Not supported by the fake database; panics via `unimplemented!()` if called.
        async fn get_user_activity_timeline(
            &self,
            _time_period: Range<OffsetDateTime>,
            _user_id: UserId,
        ) -> Result<Vec<UserActivityPeriod>> {
            unimplemented!()
        }
2024
2025        // contacts
2026
2027        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2028            self.background.simulate_random_delay().await;
2029            let mut contacts = vec![Contact::Accepted {
2030                user_id: id,
2031                should_notify: false,
2032            }];
2033
2034            for contact in self.contacts.lock().iter() {
2035                if contact.requester_id == id {
2036                    if contact.accepted {
2037                        contacts.push(Contact::Accepted {
2038                            user_id: contact.responder_id,
2039                            should_notify: contact.should_notify,
2040                        });
2041                    } else {
2042                        contacts.push(Contact::Outgoing {
2043                            user_id: contact.responder_id,
2044                        });
2045                    }
2046                } else if contact.responder_id == id {
2047                    if contact.accepted {
2048                        contacts.push(Contact::Accepted {
2049                            user_id: contact.requester_id,
2050                            should_notify: false,
2051                        });
2052                    } else {
2053                        contacts.push(Contact::Incoming {
2054                            user_id: contact.requester_id,
2055                            should_notify: contact.should_notify,
2056                        });
2057                    }
2058                }
2059            }
2060
2061            contacts.sort_unstable_by_key(|contact| contact.user_id());
2062            Ok(contacts)
2063        }
2064
2065        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2066            self.background.simulate_random_delay().await;
2067            Ok(self.contacts.lock().iter().any(|contact| {
2068                contact.accepted
2069                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2070                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2071            }))
2072        }
2073
2074        async fn send_contact_request(
2075            &self,
2076            requester_id: UserId,
2077            responder_id: UserId,
2078        ) -> Result<()> {
2079            self.background.simulate_random_delay().await;
2080            let mut contacts = self.contacts.lock();
2081            for contact in contacts.iter_mut() {
2082                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2083                    if contact.accepted {
2084                        Err(anyhow!("contact already exists"))?;
2085                    } else {
2086                        Err(anyhow!("contact already requested"))?;
2087                    }
2088                }
2089                if contact.responder_id == requester_id && contact.requester_id == responder_id {
2090                    if contact.accepted {
2091                        Err(anyhow!("contact already exists"))?;
2092                    } else {
2093                        contact.accepted = true;
2094                        contact.should_notify = false;
2095                        return Ok(());
2096                    }
2097                }
2098            }
2099            contacts.push(FakeContact {
2100                requester_id,
2101                responder_id,
2102                accepted: false,
2103                should_notify: true,
2104            });
2105            Ok(())
2106        }
2107
2108        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2109            self.background.simulate_random_delay().await;
2110            self.contacts.lock().retain(|contact| {
2111                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2112            });
2113            Ok(())
2114        }
2115
2116        async fn dismiss_contact_notification(
2117            &self,
2118            user_id: UserId,
2119            contact_user_id: UserId,
2120        ) -> Result<()> {
2121            self.background.simulate_random_delay().await;
2122            let mut contacts = self.contacts.lock();
2123            for contact in contacts.iter_mut() {
2124                if contact.requester_id == contact_user_id
2125                    && contact.responder_id == user_id
2126                    && !contact.accepted
2127                {
2128                    contact.should_notify = false;
2129                    return Ok(());
2130                }
2131                if contact.requester_id == user_id
2132                    && contact.responder_id == contact_user_id
2133                    && contact.accepted
2134                {
2135                    contact.should_notify = false;
2136                    return Ok(());
2137                }
2138            }
2139            Err(anyhow!("no such notification"))?
2140        }
2141
2142        async fn respond_to_contact_request(
2143            &self,
2144            responder_id: UserId,
2145            requester_id: UserId,
2146            accept: bool,
2147        ) -> Result<()> {
2148            self.background.simulate_random_delay().await;
2149            let mut contacts = self.contacts.lock();
2150            for (ix, contact) in contacts.iter_mut().enumerate() {
2151                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2152                    if contact.accepted {
2153                        Err(anyhow!("contact already confirmed"))?;
2154                    }
2155                    if accept {
2156                        contact.accepted = true;
2157                        contact.should_notify = true;
2158                    } else {
2159                        contacts.remove(ix);
2160                    }
2161                    return Ok(());
2162                }
2163            }
2164            Err(anyhow!("no such contact request"))?
2165        }
2166
        /// Not supported by the fake database; panics via `unimplemented!()` if called.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }
2175
        /// Not supported by the fake database; panics via `unimplemented!()` if called.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }
2179
        /// Not supported by the fake database; panics via `unimplemented!()` if called.
        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }
2183
2184        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2185            self.background.simulate_random_delay().await;
2186            let mut orgs = self.orgs.lock();
2187            if orgs.values().any(|org| org.slug == slug) {
2188                Err(anyhow!("org already exists"))?
2189            } else {
2190                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2191                orgs.insert(
2192                    org_id,
2193                    Org {
2194                        id: org_id,
2195                        name: name.to_string(),
2196                        slug: slug.to_string(),
2197                    },
2198                );
2199                Ok(org_id)
2200            }
2201        }
2202
2203        async fn add_org_member(
2204            &self,
2205            org_id: OrgId,
2206            user_id: UserId,
2207            is_admin: bool,
2208        ) -> Result<()> {
2209            self.background.simulate_random_delay().await;
2210            if !self.orgs.lock().contains_key(&org_id) {
2211                Err(anyhow!("org does not exist"))?;
2212            }
2213            if !self.users.lock().contains_key(&user_id) {
2214                Err(anyhow!("user does not exist"))?;
2215            }
2216
2217            self.org_memberships
2218                .lock()
2219                .entry((org_id, user_id))
2220                .or_insert(is_admin);
2221            Ok(())
2222        }
2223
2224        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2225            self.background.simulate_random_delay().await;
2226            if !self.orgs.lock().contains_key(&org_id) {
2227                Err(anyhow!("org does not exist"))?;
2228            }
2229
2230            let mut channels = self.channels.lock();
2231            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2232            channels.insert(
2233                channel_id,
2234                Channel {
2235                    id: channel_id,
2236                    name: name.to_string(),
2237                    owner_id: org_id.0,
2238                    owner_is_user: false,
2239                },
2240            );
2241            Ok(channel_id)
2242        }
2243
2244        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2245            self.background.simulate_random_delay().await;
2246            Ok(self
2247                .channels
2248                .lock()
2249                .values()
2250                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2251                .cloned()
2252                .collect())
2253        }
2254
2255        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2256            self.background.simulate_random_delay().await;
2257            let channels = self.channels.lock();
2258            let memberships = self.channel_memberships.lock();
2259            Ok(channels
2260                .values()
2261                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2262                .cloned()
2263                .collect())
2264        }
2265
2266        async fn can_user_access_channel(
2267            &self,
2268            user_id: UserId,
2269            channel_id: ChannelId,
2270        ) -> Result<bool> {
2271            self.background.simulate_random_delay().await;
2272            Ok(self
2273                .channel_memberships
2274                .lock()
2275                .contains_key(&(channel_id, user_id)))
2276        }
2277
2278        async fn add_channel_member(
2279            &self,
2280            channel_id: ChannelId,
2281            user_id: UserId,
2282            is_admin: bool,
2283        ) -> Result<()> {
2284            self.background.simulate_random_delay().await;
2285            if !self.channels.lock().contains_key(&channel_id) {
2286                Err(anyhow!("channel does not exist"))?;
2287            }
2288            if !self.users.lock().contains_key(&user_id) {
2289                Err(anyhow!("user does not exist"))?;
2290            }
2291
2292            self.channel_memberships
2293                .lock()
2294                .entry((channel_id, user_id))
2295                .or_insert(is_admin);
2296            Ok(())
2297        }
2298
2299        async fn create_channel_message(
2300            &self,
2301            channel_id: ChannelId,
2302            sender_id: UserId,
2303            body: &str,
2304            timestamp: OffsetDateTime,
2305            nonce: u128,
2306        ) -> Result<MessageId> {
2307            self.background.simulate_random_delay().await;
2308            if !self.channels.lock().contains_key(&channel_id) {
2309                Err(anyhow!("channel does not exist"))?;
2310            }
2311            if !self.users.lock().contains_key(&sender_id) {
2312                Err(anyhow!("user does not exist"))?;
2313            }
2314
2315            let mut messages = self.channel_messages.lock();
2316            if let Some(message) = messages
2317                .values()
2318                .find(|message| message.nonce.as_u128() == nonce)
2319            {
2320                Ok(message.id)
2321            } else {
2322                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2323                messages.insert(
2324                    message_id,
2325                    ChannelMessage {
2326                        id: message_id,
2327                        channel_id,
2328                        sender_id,
2329                        body: body.to_string(),
2330                        sent_at: timestamp,
2331                        nonce: Uuid::from_u128(nonce),
2332                    },
2333                );
2334                Ok(message_id)
2335            }
2336        }
2337
2338        async fn get_channel_messages(
2339            &self,
2340            channel_id: ChannelId,
2341            count: usize,
2342            before_id: Option<MessageId>,
2343        ) -> Result<Vec<ChannelMessage>> {
2344            self.background.simulate_random_delay().await;
2345            let mut messages = self
2346                .channel_messages
2347                .lock()
2348                .values()
2349                .rev()
2350                .filter(|message| {
2351                    message.channel_id == channel_id
2352                        && message.id < before_id.unwrap_or(MessageId::MAX)
2353                })
2354                .take(count)
2355                .cloned()
2356                .collect::<Vec<_>>();
2357            messages.sort_unstable_by_key(|message| message.id);
2358            Ok(messages)
2359        }
2360
        /// Does nothing: the fake has no external database to drop, so the URL is ignored.
        async fn teardown(&self, _: &str) {}
2362
        /// Test-only downcast hook: always returns `Some(self)`, since this is the fake
        /// implementation.
        #[cfg(test)]
        fn as_fake(&self) -> Option<&FakeDb> {
            Some(self)
        }
2367    }
2368
    /// A database handle for tests, created either against a freshly created Postgres
    /// database (`TestDb::postgres`) or around a `FakeDb` (`TestDb::fake`).
    pub struct TestDb {
        /// The database under test. Both constructors set this to `Some`; `Drop` takes
        /// it in order to run `teardown`.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the Postgres test database; empty for the fake.
        pub url: String,
    }
2373
2374    impl TestDb {
2375        #[allow(clippy::await_holding_lock)]
2376        pub async fn postgres() -> Self {
2377            lazy_static! {
2378                static ref LOCK: Mutex<()> = Mutex::new(());
2379            }
2380
2381            let _guard = LOCK.lock();
2382            let mut rng = StdRng::from_entropy();
2383            let name = format!("zed-test-{}", rng.gen::<u128>());
2384            let url = format!("postgres://postgres@localhost/{}", name);
2385            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2386            Postgres::create_database(&url)
2387                .await
2388                .expect("failed to create test db");
2389            let db = PostgresDb::new(&url, 5).await.unwrap();
2390            let migrator = Migrator::new(migrations_path).await.unwrap();
2391            migrator.run(&db.pool).await.unwrap();
2392            Self {
2393                db: Some(Arc::new(db)),
2394                url,
2395            }
2396        }
2397
2398        pub fn fake(background: Arc<Background>) -> Self {
2399            Self {
2400                db: Some(Arc::new(FakeDb::new(background))),
2401                url: Default::default(),
2402            }
2403        }
2404
2405        pub fn db(&self) -> &Arc<dyn Db> {
2406            self.db.as_ref().unwrap()
2407        }
2408    }
2409
2410    impl Drop for TestDb {
2411        fn drop(&mut self) {
2412            if let Some(db) = self.db.take() {
2413                futures::executor::block_on(db.teardown(&self.url));
2414            }
2415        }
2416    }
2417}