db.rs

   1use std::{cmp, ops::Range, time::Duration};
   2
   3use crate::{Error, Result};
   4use anyhow::{anyhow, Context};
   5use async_trait::async_trait;
   6use axum::http::StatusCode;
   7use collections::HashMap;
   8use futures::StreamExt;
   9use nanoid::nanoid;
  10use serde::{Deserialize, Serialize};
  11pub use sqlx::postgres::PgPoolOptions as DbOptions;
  12use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
  13use time::{OffsetDateTime, PrimitiveDateTime};
  14
/// Storage operations for users, invite codes, projects and their activity, contacts, access
/// tokens, organizations and channels.
///
/// [`PostgresDb`] is the Postgres-backed implementation in this file. Methods gated by
/// `#[cfg(any(test, feature = "seed-support"))]` exist only in test or `seed-support` builds;
/// `#[cfg(test)]` methods exist only in tests.
#[async_trait]
pub trait Db: Send + Sync {
    // users

    /// Creates a user and returns its id. The Postgres implementation returns the existing
    /// user's id (leaving that row unchanged) when `github_login` is already taken.
    async fn create_user(
        &self,
        github_login: &str,
        email_address: Option<&str>,
        admin: bool,
    ) -> Result<UserId>;
    /// Returns page `page` (zero-based) of users, with `limit` users per page.
    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
    /// Bulk-creates (or updates) users from `(github_login, email_address, invite_count)`
    /// tuples and returns their ids.
    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
    /// Returns up to `limit` users whose GitHub login fuzzily matches `query`.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Returns the user with the given id, if any.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Returns the users whose ids appear in `ids`; ids without a matching user are skipped.
    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns the user with exactly this GitHub login, if any.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets or clears the admin flag of a user.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Sets the `connected_once` flag of a user.
    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
    /// Deletes a user. The Postgres implementation also deletes the user's access tokens.
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    // invite codes

    /// Sets how many invites a user may send. In the Postgres implementation a positive
    /// count also assigns an invite code if the user does not have one yet.
    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
    /// Returns the user's invite code and remaining invite count, or `None` if the user has
    /// no invite code.
    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
    /// Returns the user owning `code`; fails with an HTTP 404 error if no user owns it.
    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
    /// Consumes one invite from the owner of `code`, creates a new user with `login` and
    /// `email_address`, and returns the new user's id.
    async fn redeem_invite_code(
        &self,
        code: &str,
        login: &str,
        email_address: Option<&str>,
    ) -> Result<UserId>;

    /// Registers a new project for the given user.
    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;

    /// Unregisters a project for the given project id.
    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;

    /// Update file counts by extension for the given project and worktree.
    async fn update_worktree_extensions(
        &self,
        project_id: ProjectId,
        worktree_id: u64,
        extensions: HashMap<String, u32>,
    ) -> Result<()>;

    /// Get the file counts on the given project keyed by their worktree and extension.
    async fn get_project_extensions(
        &self,
        project_id: ProjectId,
    ) -> Result<HashMap<u64, HashMap<String, usize>>>;

    /// Record which users have been active in which projects during
    /// a given period of time.
    async fn record_user_activity(
        &self,
        time_period: Range<OffsetDateTime>,
        active_projects: &[(UserId, ProjectId)],
    ) -> Result<()>;

    /// Get the number of users who have been active in the given
    /// time period for at least the given time duration.
    async fn get_active_user_count(
        &self,
        time_period: Range<OffsetDateTime>,
        min_duration: Duration,
        only_collaborative: bool,
    ) -> Result<usize>;

    /// Get the users that have been most active during the given time period,
    /// along with the amount of time they have been active in each project.
    async fn get_top_users_activity_summary(
        &self,
        time_period: Range<OffsetDateTime>,
        max_user_count: usize,
    ) -> Result<Vec<UserActivitySummary>>;

    /// Get the project activity for the given user and time period.
    async fn get_user_activity_timeline(
        &self,
        time_period: Range<OffsetDateTime>,
        user_id: UserId,
    ) -> Result<Vec<UserActivityPeriod>>;

    // contacts

    /// Returns the contacts of user `id`. The Postgres implementation also lists the user as
    /// an accepted contact of themselves.
    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    /// Returns whether the two users have an accepted contact. The Postgres implementation
    /// normalizes the argument order, so either order gives the same answer.
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    // access tokens

    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
    #[cfg(any(test, feature = "seed-support"))]
    // NOTE(review): the attribute above still gates the item below (blank lines between an
    // attribute and its item are not significant), so this method only exists in test /
    // `seed-support` builds. Confirm that is intended.
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    #[cfg(any(test, feature = "seed-support"))]
    // NOTE(review): as above, this attribute gates `get_org_channels` below, so it is only
    // declared in test / `seed-support` builds. Confirm that is intended.
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    /// Test-only cleanup hook for the database at `url`.
    #[cfg(test)]
    async fn teardown(&self, url: &str);
    /// Test-only downcast; presumably `Some` only when `self` is a `tests::FakeDb` — confirm.
    #[cfg(test)]
    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
}
 160
 161pub struct PostgresDb {
 162    pool: sqlx::PgPool,
 163}
 164
 165impl PostgresDb {
 166    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 167        let pool = DbOptions::new()
 168            .max_connections(max_connections)
 169            .connect(&url)
 170            .await
 171            .context("failed to connect to postgres database")?;
 172        Ok(Self { pool })
 173    }
 174}
 175
 176#[async_trait]
 177impl Db for PostgresDb {
 178    // users
 179
 180    async fn create_user(
 181        &self,
 182        github_login: &str,
 183        email_address: Option<&str>,
 184        admin: bool,
 185    ) -> Result<UserId> {
 186        let query = "
 187            INSERT INTO users (github_login, email_address, admin)
 188            VALUES ($1, $2, $3)
 189            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 190            RETURNING id
 191        ";
 192        Ok(sqlx::query_scalar(query)
 193            .bind(github_login)
 194            .bind(email_address)
 195            .bind(admin)
 196            .fetch_one(&self.pool)
 197            .await
 198            .map(UserId)?)
 199    }
 200
 201    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 202        let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 203        Ok(sqlx::query_as(query)
 204            .bind(limit as i32)
 205            .bind((page * limit) as i32)
 206            .fetch_all(&self.pool)
 207            .await?)
 208    }
 209
 210    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
 211        let mut query = QueryBuilder::new(
 212            "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
 213        );
 214        query.push_values(
 215            users,
 216            |mut query, (github_login, email_address, invite_count)| {
 217                query
 218                    .push_bind(github_login)
 219                    .push_bind(email_address)
 220                    .push_bind(false)
 221                    .push_bind(nanoid!(16))
 222                    .push_bind(invite_count as i32);
 223            },
 224        );
 225        query.push(
 226            "
 227            ON CONFLICT (github_login) DO UPDATE SET
 228                github_login = excluded.github_login,
 229                invite_count = excluded.invite_count,
 230                invite_code = CASE WHEN users.invite_code IS NULL
 231                                   THEN excluded.invite_code
 232                                   ELSE users.invite_code
 233                              END
 234            RETURNING id
 235            ",
 236        );
 237
 238        let rows = query.build().fetch_all(&self.pool).await?;
 239        Ok(rows
 240            .into_iter()
 241            .filter_map(|row| row.try_get::<UserId, _>(0).ok())
 242            .collect())
 243    }
 244
 245    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 246        let like_string = fuzzy_like_string(name_query);
 247        let query = "
 248            SELECT users.*
 249            FROM users
 250            WHERE github_login ILIKE $1
 251            ORDER BY github_login <-> $2
 252            LIMIT $3
 253        ";
 254        Ok(sqlx::query_as(query)
 255            .bind(like_string)
 256            .bind(name_query)
 257            .bind(limit as i32)
 258            .fetch_all(&self.pool)
 259            .await?)
 260    }
 261
 262    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 263        let users = self.get_users_by_ids(vec![id]).await?;
 264        Ok(users.into_iter().next())
 265    }
 266
 267    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 268        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 269        let query = "
 270            SELECT users.*
 271            FROM users
 272            WHERE users.id = ANY ($1)
 273        ";
 274        Ok(sqlx::query_as(query)
 275            .bind(&ids)
 276            .fetch_all(&self.pool)
 277            .await?)
 278    }
 279
 280    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
 281        let query = format!(
 282            "
 283            SELECT users.*
 284            FROM users
 285            WHERE invite_count = 0
 286            AND inviter_id IS{} NULL
 287            ",
 288            if invited_by_another_user { " NOT" } else { "" }
 289        );
 290
 291        Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 292    }
 293
 294    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 295        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 296        Ok(sqlx::query_as(query)
 297            .bind(github_login)
 298            .fetch_optional(&self.pool)
 299            .await?)
 300    }
 301
 302    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 303        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 304        Ok(sqlx::query(query)
 305            .bind(is_admin)
 306            .bind(id.0)
 307            .execute(&self.pool)
 308            .await
 309            .map(drop)?)
 310    }
 311
 312    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 313        let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 314        Ok(sqlx::query(query)
 315            .bind(connected_once)
 316            .bind(id.0)
 317            .execute(&self.pool)
 318            .await
 319            .map(drop)?)
 320    }
 321
 322    async fn destroy_user(&self, id: UserId) -> Result<()> {
 323        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 324        sqlx::query(query)
 325            .bind(id.0)
 326            .execute(&self.pool)
 327            .await
 328            .map(drop)?;
 329        let query = "DELETE FROM users WHERE id = $1;";
 330        Ok(sqlx::query(query)
 331            .bind(id.0)
 332            .execute(&self.pool)
 333            .await
 334            .map(drop)?)
 335    }
 336
 337    // invite codes
 338
 339    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
 340        let mut tx = self.pool.begin().await?;
 341        if count > 0 {
 342            sqlx::query(
 343                "
 344                UPDATE users
 345                SET invite_code = $1
 346                WHERE id = $2 AND invite_code IS NULL
 347            ",
 348            )
 349            .bind(nanoid!(16))
 350            .bind(id)
 351            .execute(&mut tx)
 352            .await?;
 353        }
 354
 355        sqlx::query(
 356            "
 357            UPDATE users
 358            SET invite_count = $1
 359            WHERE id = $2
 360            ",
 361        )
 362        .bind(count as i32)
 363        .bind(id)
 364        .execute(&mut tx)
 365        .await?;
 366        tx.commit().await?;
 367        Ok(())
 368    }
 369
 370    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 371        let result: Option<(String, i32)> = sqlx::query_as(
 372            "
 373                SELECT invite_code, invite_count
 374                FROM users
 375                WHERE id = $1 AND invite_code IS NOT NULL 
 376            ",
 377        )
 378        .bind(id)
 379        .fetch_optional(&self.pool)
 380        .await?;
 381        if let Some((code, count)) = result {
 382            Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 383        } else {
 384            Ok(None)
 385        }
 386    }
 387
 388    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 389        sqlx::query_as(
 390            "
 391                SELECT *
 392                FROM users
 393                WHERE invite_code = $1
 394            ",
 395        )
 396        .bind(code)
 397        .fetch_optional(&self.pool)
 398        .await?
 399        .ok_or_else(|| {
 400            Error::Http(
 401                StatusCode::NOT_FOUND,
 402                "that invite code does not exist".to_string(),
 403            )
 404        })
 405    }
 406
 407    async fn redeem_invite_code(
 408        &self,
 409        code: &str,
 410        login: &str,
 411        email_address: Option<&str>,
 412    ) -> Result<UserId> {
 413        let mut tx = self.pool.begin().await?;
 414
 415        let inviter_id: Option<UserId> = sqlx::query_scalar(
 416            "
 417                UPDATE users
 418                SET invite_count = invite_count - 1
 419                WHERE
 420                    invite_code = $1 AND
 421                    invite_count > 0
 422                RETURNING id
 423            ",
 424        )
 425        .bind(code)
 426        .fetch_optional(&mut tx)
 427        .await?;
 428
 429        let inviter_id = match inviter_id {
 430            Some(inviter_id) => inviter_id,
 431            None => {
 432                if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
 433                    .bind(code)
 434                    .fetch_optional(&mut tx)
 435                    .await?
 436                    .is_some()
 437                {
 438                    Err(Error::Http(
 439                        StatusCode::UNAUTHORIZED,
 440                        "no invites remaining".to_string(),
 441                    ))?
 442                } else {
 443                    Err(Error::Http(
 444                        StatusCode::NOT_FOUND,
 445                        "invite code not found".to_string(),
 446                    ))?
 447                }
 448            }
 449        };
 450
 451        let invitee_id = sqlx::query_scalar(
 452            "
 453                INSERT INTO users
 454                    (github_login, email_address, admin, inviter_id)
 455                VALUES
 456                    ($1, $2, 'f', $3)
 457                RETURNING id
 458            ",
 459        )
 460        .bind(login)
 461        .bind(email_address)
 462        .bind(inviter_id)
 463        .fetch_one(&mut tx)
 464        .await
 465        .map(UserId)?;
 466
 467        sqlx::query(
 468            "
 469                INSERT INTO contacts
 470                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 471                VALUES
 472                    ($1, $2, 't', 't', 't')
 473            ",
 474        )
 475        .bind(inviter_id)
 476        .bind(invitee_id)
 477        .execute(&mut tx)
 478        .await?;
 479
 480        tx.commit().await?;
 481        Ok(invitee_id)
 482    }
 483
 484    // projects
 485
 486    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
 487        Ok(sqlx::query_scalar(
 488            "
 489            INSERT INTO projects(host_user_id)
 490            VALUES ($1)
 491            RETURNING id
 492            ",
 493        )
 494        .bind(host_user_id)
 495        .fetch_one(&self.pool)
 496        .await
 497        .map(ProjectId)?)
 498    }
 499
 500    async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
 501        sqlx::query(
 502            "
 503            UPDATE projects
 504            SET unregistered = 't'
 505            WHERE id = $1
 506            ",
 507        )
 508        .bind(project_id)
 509        .execute(&self.pool)
 510        .await?;
 511        Ok(())
 512    }
 513
 514    async fn update_worktree_extensions(
 515        &self,
 516        project_id: ProjectId,
 517        worktree_id: u64,
 518        extensions: HashMap<String, u32>,
 519    ) -> Result<()> {
 520        if extensions.is_empty() {
 521            return Ok(());
 522        }
 523
 524        let mut query = QueryBuilder::new(
 525            "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
 526        );
 527        query.push_values(extensions, |mut query, (extension, count)| {
 528            query
 529                .push_bind(project_id)
 530                .push_bind(worktree_id as i32)
 531                .push_bind(extension)
 532                .push_bind(count as i32);
 533        });
 534        query.push(
 535            "
 536            ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
 537            count = excluded.count
 538            ",
 539        );
 540        query.build().execute(&self.pool).await?;
 541
 542        Ok(())
 543    }
 544
 545    async fn get_project_extensions(
 546        &self,
 547        project_id: ProjectId,
 548    ) -> Result<HashMap<u64, HashMap<String, usize>>> {
 549        #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
 550        struct WorktreeExtension {
 551            worktree_id: i32,
 552            extension: String,
 553            count: i32,
 554        }
 555
 556        let query = "
 557            SELECT worktree_id, extension, count
 558            FROM worktree_extensions
 559            WHERE project_id = $1
 560        ";
 561        let counts = sqlx::query_as::<_, WorktreeExtension>(query)
 562            .bind(&project_id)
 563            .fetch_all(&self.pool)
 564            .await?;
 565
 566        let mut extension_counts = HashMap::default();
 567        for count in counts {
 568            extension_counts
 569                .entry(count.worktree_id as u64)
 570                .or_insert(HashMap::default())
 571                .insert(count.extension, count.count as usize);
 572        }
 573        Ok(extension_counts)
 574    }
 575
 576    async fn record_user_activity(
 577        &self,
 578        time_period: Range<OffsetDateTime>,
 579        projects: &[(UserId, ProjectId)],
 580    ) -> Result<()> {
 581        let query = "
 582            INSERT INTO project_activity_periods
 583            (ended_at, duration_millis, user_id, project_id)
 584            VALUES
 585            ($1, $2, $3, $4);
 586        ";
 587
 588        let mut tx = self.pool.begin().await?;
 589        let duration_millis =
 590            ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
 591        for (user_id, project_id) in projects {
 592            sqlx::query(query)
 593                .bind(time_period.end)
 594                .bind(duration_millis)
 595                .bind(user_id)
 596                .bind(project_id)
 597                .execute(&mut tx)
 598                .await?;
 599        }
 600        tx.commit().await?;
 601
 602        Ok(())
 603    }
 604
 605    async fn get_active_user_count(
 606        &self,
 607        time_period: Range<OffsetDateTime>,
 608        min_duration: Duration,
 609        only_collaborative: bool,
 610    ) -> Result<usize> {
 611        let mut with_clause = String::new();
 612        with_clause.push_str("WITH\n");
 613        with_clause.push_str(
 614            "
 615            project_durations AS (
 616                SELECT user_id, project_id, SUM(duration_millis) AS project_duration
 617                FROM project_activity_periods
 618                WHERE $1 < ended_at AND ended_at <= $2
 619                GROUP BY user_id, project_id
 620            ),
 621            ",
 622        );
 623        with_clause.push_str(
 624            "
 625            project_collaborators as (
 626                SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
 627                FROM project_durations
 628                GROUP BY project_id
 629            ),
 630            ",
 631        );
 632
 633        if only_collaborative {
 634            with_clause.push_str(
 635                "
 636                user_durations AS (
 637                    SELECT user_id, SUM(project_duration) as total_duration
 638                    FROM project_durations, project_collaborators
 639                    WHERE
 640                        project_durations.project_id = project_collaborators.project_id AND
 641                        max_collaborators > 1
 642                    GROUP BY user_id
 643                    ORDER BY total_duration DESC
 644                    LIMIT $3
 645                )
 646                ",
 647            );
 648        } else {
 649            with_clause.push_str(
 650                "
 651                user_durations AS (
 652                    SELECT user_id, SUM(project_duration) as total_duration
 653                    FROM project_durations
 654                    GROUP BY user_id
 655                    ORDER BY total_duration DESC
 656                    LIMIT $3
 657                )
 658                ",
 659            );
 660        }
 661
 662        let query = format!(
 663            "
 664            {with_clause}
 665            SELECT count(user_durations.user_id)
 666            FROM user_durations
 667            WHERE user_durations.total_duration >= $3
 668            "
 669        );
 670
 671        let count: i64 = sqlx::query_scalar(&query)
 672            .bind(time_period.start)
 673            .bind(time_period.end)
 674            .bind(min_duration.as_millis() as i64)
 675            .fetch_one(&self.pool)
 676            .await?;
 677        Ok(count as usize)
 678    }
 679
 680    async fn get_top_users_activity_summary(
 681        &self,
 682        time_period: Range<OffsetDateTime>,
 683        max_user_count: usize,
 684    ) -> Result<Vec<UserActivitySummary>> {
 685        let query = "
 686            WITH
 687                project_durations AS (
 688                    SELECT user_id, project_id, SUM(duration_millis) AS project_duration
 689                    FROM project_activity_periods
 690                    WHERE $1 < ended_at AND ended_at <= $2
 691                    GROUP BY user_id, project_id
 692                ),
 693                user_durations AS (
 694                    SELECT user_id, SUM(project_duration) as total_duration
 695                    FROM project_durations
 696                    GROUP BY user_id
 697                    ORDER BY total_duration DESC
 698                    LIMIT $3
 699                ),
 700                project_collaborators as (
 701                    SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
 702                    FROM project_durations
 703                    GROUP BY project_id
 704                )
 705            SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
 706            FROM user_durations, project_durations, project_collaborators, users
 707            WHERE
 708                user_durations.user_id = project_durations.user_id AND
 709                user_durations.user_id = users.id AND
 710                project_durations.project_id = project_collaborators.project_id
 711            ORDER BY total_duration DESC, user_id ASC
 712        ";
 713
 714        let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
 715            .bind(time_period.start)
 716            .bind(time_period.end)
 717            .bind(max_user_count as i32)
 718            .fetch(&self.pool);
 719
 720        let mut result = Vec::<UserActivitySummary>::new();
 721        while let Some(row) = rows.next().await {
 722            let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?;
 723            let project_id = project_id;
 724            let duration = Duration::from_millis(duration_millis as u64);
 725            let project_activity = ProjectActivitySummary {
 726                id: project_id,
 727                duration,
 728                max_collaborators: project_collaborators as usize,
 729            };
 730            if let Some(last_summary) = result.last_mut() {
 731                if last_summary.id == user_id {
 732                    last_summary.project_activity.push(project_activity);
 733                    continue;
 734                }
 735            }
 736            result.push(UserActivitySummary {
 737                id: user_id,
 738                project_activity: vec![project_activity],
 739                github_login,
 740            });
 741        }
 742
 743        Ok(result)
 744    }
 745
    /// Returns `user_id`'s activity periods within `time_period`, sorted by start time.
    ///
    /// Within each project, consecutive periods that overlap or are separated by at most
    /// `COALESCE_THRESHOLD` are merged into one. Each resulting period also carries
    /// file-extension counts taken from the project's `worktree_extensions` rows.
    async fn get_user_activity_timeline(
        &self,
        time_period: Range<OffsetDateTime>,
        user_id: UserId,
    ) -> Result<Vec<UserActivityPeriod>> {
        // Maximum gap between two periods of the same project that still lets them merge.
        const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);

        // The LEFT OUTER JOIN yields one row per (activity period, worktree extension) pair, or
        // a single row with NULL extension/count when the project has no extension data. Rows
        // are processed in `project_activity_periods.id` order (presumably insertion order).
        let query = "
            SELECT
                project_activity_periods.ended_at,
                project_activity_periods.duration_millis,
                project_activity_periods.project_id,
                worktree_extensions.extension,
                worktree_extensions.count
            FROM project_activity_periods
            LEFT OUTER JOIN
                worktree_extensions
            ON
                project_activity_periods.project_id = worktree_extensions.project_id
            WHERE
                project_activity_periods.user_id = $1 AND
                $2 < project_activity_periods.ended_at AND
                project_activity_periods.ended_at <= $3
            ORDER BY project_activity_periods.id ASC
        ";

        let mut rows = sqlx::query_as::<
            _,
            (
                PrimitiveDateTime,
                i32,
                ProjectId,
                Option<String>,
                Option<i32>,
            ),
        >(query)
        .bind(user_id)
        .bind(time_period.start)
        .bind(time_period.end)
        .fetch(&self.pool);

        let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
        while let Some(row) = rows.next().await {
            let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
            // `ended_at` is decoded without an offset and interpreted as UTC.
            let ended_at = ended_at.assume_utc();
            let duration = Duration::from_millis(duration_millis as u64);
            let started_at = ended_at - duration;
            let project_time_periods = time_periods.entry(project_id).or_default();

            // Merge into the project's latest period when this one starts no later than
            // COALESCE_THRESHOLD after it ends and does not end before it starts. Only `end`
            // is widened; the earlier period's `start` is kept. Otherwise open a new period.
            // Repeated join rows for the same period always satisfy this and merge.
            if let Some(prev_duration) = project_time_periods.last_mut() {
                if started_at <= prev_duration.end + COALESCE_THRESHOLD
                    && ended_at >= prev_duration.start
                {
                    prev_duration.end = cmp::max(prev_duration.end, ended_at);
                } else {
                    project_time_periods.push(UserActivityPeriod {
                        project_id,
                        start: started_at,
                        end: ended_at,
                        extensions: Default::default(),
                    });
                }
            } else {
                project_time_periods.push(UserActivityPeriod {
                    project_id,
                    start: started_at,
                    end: ended_at,
                    extensions: Default::default(),
                });
            }

            // Attach this row's extension count to the period it was just merged into or
            // opened. NOTE(review): counts are keyed only by extension, so the same extension
            // from different worktrees overwrites rather than sums. Confirm that is intended.
            if let Some((extension, extension_count)) = extension.zip(extension_count) {
                project_time_periods
                    .last_mut()
                    // Never panics: a period was pushed or extended just above.
                    .unwrap()
                    .extensions
                    .insert(extension, extension_count as usize);
            }
        }

        // Flatten every project's periods into one list ordered by start time.
        let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
        durations.sort_unstable_by_key(|duration| duration.start);
        Ok(durations)
    }
 830
 831    // contacts
 832
 833    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 834        let query = "
 835            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 836            FROM contacts
 837            WHERE user_id_a = $1 OR user_id_b = $1;
 838        ";
 839
 840        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 841            .bind(user_id)
 842            .fetch(&self.pool);
 843
 844        let mut contacts = vec![Contact::Accepted {
 845            user_id,
 846            should_notify: false,
 847        }];
 848        while let Some(row) = rows.next().await {
 849            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 850
 851            if user_id_a == user_id {
 852                if accepted {
 853                    contacts.push(Contact::Accepted {
 854                        user_id: user_id_b,
 855                        should_notify: should_notify && a_to_b,
 856                    });
 857                } else if a_to_b {
 858                    contacts.push(Contact::Outgoing { user_id: user_id_b })
 859                } else {
 860                    contacts.push(Contact::Incoming {
 861                        user_id: user_id_b,
 862                        should_notify,
 863                    });
 864                }
 865            } else {
 866                if accepted {
 867                    contacts.push(Contact::Accepted {
 868                        user_id: user_id_a,
 869                        should_notify: should_notify && !a_to_b,
 870                    });
 871                } else if a_to_b {
 872                    contacts.push(Contact::Incoming {
 873                        user_id: user_id_a,
 874                        should_notify,
 875                    });
 876                } else {
 877                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 878                }
 879            }
 880        }
 881
 882        contacts.sort_unstable_by_key(|contact| contact.user_id());
 883
 884        Ok(contacts)
 885    }
 886
 887    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 888        let (id_a, id_b) = if user_id_1 < user_id_2 {
 889            (user_id_1, user_id_2)
 890        } else {
 891            (user_id_2, user_id_1)
 892        };
 893
 894        let query = "
 895            SELECT 1 FROM contacts
 896            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
 897            LIMIT 1
 898        ";
 899        Ok(sqlx::query_scalar::<_, i32>(query)
 900            .bind(id_a.0)
 901            .bind(id_b.0)
 902            .fetch_optional(&self.pool)
 903            .await?
 904            .is_some())
 905    }
 906
 907    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 908        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 909            (sender_id, receiver_id, true)
 910        } else {
 911            (receiver_id, sender_id, false)
 912        };
 913        let query = "
 914            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 915            VALUES ($1, $2, $3, 'f', 't')
 916            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 917            SET
 918                accepted = 't',
 919                should_notify = 'f'
 920            WHERE
 921                NOT contacts.accepted AND
 922                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 923                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 924        ";
 925        let result = sqlx::query(query)
 926            .bind(id_a.0)
 927            .bind(id_b.0)
 928            .bind(a_to_b)
 929            .execute(&self.pool)
 930            .await?;
 931
 932        if result.rows_affected() == 1 {
 933            Ok(())
 934        } else {
 935            Err(anyhow!("contact already requested"))?
 936        }
 937    }
 938
 939    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 940        let (id_a, id_b) = if responder_id < requester_id {
 941            (responder_id, requester_id)
 942        } else {
 943            (requester_id, responder_id)
 944        };
 945        let query = "
 946            DELETE FROM contacts
 947            WHERE user_id_a = $1 AND user_id_b = $2;
 948        ";
 949        let result = sqlx::query(query)
 950            .bind(id_a.0)
 951            .bind(id_b.0)
 952            .execute(&self.pool)
 953            .await?;
 954
 955        if result.rows_affected() == 1 {
 956            Ok(())
 957        } else {
 958            Err(anyhow!("no such contact"))?
 959        }
 960    }
 961
 962    async fn dismiss_contact_notification(
 963        &self,
 964        user_id: UserId,
 965        contact_user_id: UserId,
 966    ) -> Result<()> {
 967        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
 968            (user_id, contact_user_id, true)
 969        } else {
 970            (contact_user_id, user_id, false)
 971        };
 972
 973        let query = "
 974            UPDATE contacts
 975            SET should_notify = 'f'
 976            WHERE
 977                user_id_a = $1 AND user_id_b = $2 AND
 978                (
 979                    (a_to_b = $3 AND accepted) OR
 980                    (a_to_b != $3 AND NOT accepted)
 981                );
 982        ";
 983
 984        let result = sqlx::query(query)
 985            .bind(id_a.0)
 986            .bind(id_b.0)
 987            .bind(a_to_b)
 988            .execute(&self.pool)
 989            .await?;
 990
 991        if result.rows_affected() == 0 {
 992            Err(anyhow!("no such contact request"))?;
 993        }
 994
 995        Ok(())
 996    }
 997
 998    async fn respond_to_contact_request(
 999        &self,
1000        responder_id: UserId,
1001        requester_id: UserId,
1002        accept: bool,
1003    ) -> Result<()> {
1004        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1005            (responder_id, requester_id, false)
1006        } else {
1007            (requester_id, responder_id, true)
1008        };
1009        let result = if accept {
1010            let query = "
1011                UPDATE contacts
1012                SET accepted = 't', should_notify = 't'
1013                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1014            ";
1015            sqlx::query(query)
1016                .bind(id_a.0)
1017                .bind(id_b.0)
1018                .bind(a_to_b)
1019                .execute(&self.pool)
1020                .await?
1021        } else {
1022            let query = "
1023                DELETE FROM contacts
1024                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1025            ";
1026            sqlx::query(query)
1027                .bind(id_a.0)
1028                .bind(id_b.0)
1029                .bind(a_to_b)
1030                .execute(&self.pool)
1031                .await?
1032        };
1033        if result.rows_affected() == 1 {
1034            Ok(())
1035        } else {
1036            Err(anyhow!("no such contact request"))?
1037        }
1038    }
1039
1040    // access tokens
1041
1042    async fn create_access_token_hash(
1043        &self,
1044        user_id: UserId,
1045        access_token_hash: &str,
1046        max_access_token_count: usize,
1047    ) -> Result<()> {
1048        let insert_query = "
1049            INSERT INTO access_tokens (user_id, hash)
1050            VALUES ($1, $2);
1051        ";
1052        let cleanup_query = "
1053            DELETE FROM access_tokens
1054            WHERE id IN (
1055                SELECT id from access_tokens
1056                WHERE user_id = $1
1057                ORDER BY id DESC
1058                OFFSET $3
1059            )
1060        ";
1061
1062        let mut tx = self.pool.begin().await?;
1063        sqlx::query(insert_query)
1064            .bind(user_id.0)
1065            .bind(access_token_hash)
1066            .execute(&mut tx)
1067            .await?;
1068        sqlx::query(cleanup_query)
1069            .bind(user_id.0)
1070            .bind(access_token_hash)
1071            .bind(max_access_token_count as i32)
1072            .execute(&mut tx)
1073            .await?;
1074        Ok(tx.commit().await?)
1075    }
1076
1077    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1078        let query = "
1079            SELECT hash
1080            FROM access_tokens
1081            WHERE user_id = $1
1082            ORDER BY id DESC
1083        ";
1084        Ok(sqlx::query_scalar(query)
1085            .bind(user_id.0)
1086            .fetch_all(&self.pool)
1087            .await?)
1088    }
1089
1090    // orgs
1091
1092    #[allow(unused)] // Help rust-analyzer
1093    #[cfg(any(test, feature = "seed-support"))]
1094    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1095        let query = "
1096            SELECT *
1097            FROM orgs
1098            WHERE slug = $1
1099        ";
1100        Ok(sqlx::query_as(query)
1101            .bind(slug)
1102            .fetch_optional(&self.pool)
1103            .await?)
1104    }
1105
1106    #[cfg(any(test, feature = "seed-support"))]
1107    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1108        let query = "
1109            INSERT INTO orgs (name, slug)
1110            VALUES ($1, $2)
1111            RETURNING id
1112        ";
1113        Ok(sqlx::query_scalar(query)
1114            .bind(name)
1115            .bind(slug)
1116            .fetch_one(&self.pool)
1117            .await
1118            .map(OrgId)?)
1119    }
1120
1121    #[cfg(any(test, feature = "seed-support"))]
1122    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1123        let query = "
1124            INSERT INTO org_memberships (org_id, user_id, admin)
1125            VALUES ($1, $2, $3)
1126            ON CONFLICT DO NOTHING
1127        ";
1128        Ok(sqlx::query(query)
1129            .bind(org_id.0)
1130            .bind(user_id.0)
1131            .bind(is_admin)
1132            .execute(&self.pool)
1133            .await
1134            .map(drop)?)
1135    }
1136
1137    // channels
1138
1139    #[cfg(any(test, feature = "seed-support"))]
1140    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1141        let query = "
1142            INSERT INTO channels (owner_id, owner_is_user, name)
1143            VALUES ($1, false, $2)
1144            RETURNING id
1145        ";
1146        Ok(sqlx::query_scalar(query)
1147            .bind(org_id.0)
1148            .bind(name)
1149            .fetch_one(&self.pool)
1150            .await
1151            .map(ChannelId)?)
1152    }
1153
1154    #[allow(unused)] // Help rust-analyzer
1155    #[cfg(any(test, feature = "seed-support"))]
1156    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1157        let query = "
1158            SELECT *
1159            FROM channels
1160            WHERE
1161                channels.owner_is_user = false AND
1162                channels.owner_id = $1
1163        ";
1164        Ok(sqlx::query_as(query)
1165            .bind(org_id.0)
1166            .fetch_all(&self.pool)
1167            .await?)
1168    }
1169
1170    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1171        let query = "
1172            SELECT
1173                channels.*
1174            FROM
1175                channel_memberships, channels
1176            WHERE
1177                channel_memberships.user_id = $1 AND
1178                channel_memberships.channel_id = channels.id
1179        ";
1180        Ok(sqlx::query_as(query)
1181            .bind(user_id.0)
1182            .fetch_all(&self.pool)
1183            .await?)
1184    }
1185
1186    async fn can_user_access_channel(
1187        &self,
1188        user_id: UserId,
1189        channel_id: ChannelId,
1190    ) -> Result<bool> {
1191        let query = "
1192            SELECT id
1193            FROM channel_memberships
1194            WHERE user_id = $1 AND channel_id = $2
1195            LIMIT 1
1196        ";
1197        Ok(sqlx::query_scalar::<_, i32>(query)
1198            .bind(user_id.0)
1199            .bind(channel_id.0)
1200            .fetch_optional(&self.pool)
1201            .await
1202            .map(|e| e.is_some())?)
1203    }
1204
1205    #[cfg(any(test, feature = "seed-support"))]
1206    async fn add_channel_member(
1207        &self,
1208        channel_id: ChannelId,
1209        user_id: UserId,
1210        is_admin: bool,
1211    ) -> Result<()> {
1212        let query = "
1213            INSERT INTO channel_memberships (channel_id, user_id, admin)
1214            VALUES ($1, $2, $3)
1215            ON CONFLICT DO NOTHING
1216        ";
1217        Ok(sqlx::query(query)
1218            .bind(channel_id.0)
1219            .bind(user_id.0)
1220            .bind(is_admin)
1221            .execute(&self.pool)
1222            .await
1223            .map(drop)?)
1224    }
1225
1226    // messages
1227
1228    async fn create_channel_message(
1229        &self,
1230        channel_id: ChannelId,
1231        sender_id: UserId,
1232        body: &str,
1233        timestamp: OffsetDateTime,
1234        nonce: u128,
1235    ) -> Result<MessageId> {
1236        let query = "
1237            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1238            VALUES ($1, $2, $3, $4, $5)
1239            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1240            RETURNING id
1241        ";
1242        Ok(sqlx::query_scalar(query)
1243            .bind(channel_id.0)
1244            .bind(sender_id.0)
1245            .bind(body)
1246            .bind(timestamp)
1247            .bind(Uuid::from_u128(nonce))
1248            .fetch_one(&self.pool)
1249            .await
1250            .map(MessageId)?)
1251    }
1252
1253    async fn get_channel_messages(
1254        &self,
1255        channel_id: ChannelId,
1256        count: usize,
1257        before_id: Option<MessageId>,
1258    ) -> Result<Vec<ChannelMessage>> {
1259        let query = r#"
1260            SELECT * FROM (
1261                SELECT
1262                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1263                FROM
1264                    channel_messages
1265                WHERE
1266                    channel_id = $1 AND
1267                    id < $2
1268                ORDER BY id DESC
1269                LIMIT $3
1270            ) as recent_messages
1271            ORDER BY id ASC
1272        "#;
1273        Ok(sqlx::query_as(query)
1274            .bind(channel_id.0)
1275            .bind(before_id.unwrap_or(MessageId::MAX))
1276            .bind(count as i64)
1277            .fetch_all(&self.pool)
1278            .await?)
1279    }
1280
1281    #[cfg(test)]
1282    async fn teardown(&self, url: &str) {
1283        use util::ResultExt;
1284
1285        let query = "
1286            SELECT pg_terminate_backend(pg_stat_activity.pid)
1287            FROM pg_stat_activity
1288            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1289        ";
1290        sqlx::query(query).execute(&self.pool).await.log_err();
1291        self.pool.close().await;
1292        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1293            .await
1294            .log_err();
1295    }
1296
    /// Test-only accessor for the in-memory `tests::FakeDb`.
    ///
    /// This pool-backed implementation is not a fake, so it always returns `None`.
    #[cfg(test)]
    fn as_fake(&self) -> Option<&tests::FakeDb> {
        None
    }
1301}
1302
/// Declares `$name` as a newtype wrapper around an `i32` database id.
///
/// The generated type has the following properties:
/// - It is `Copy`, ordered, and hashable.
/// - It encodes transparently as its inner `i32`, both for sqlx and for serde.
/// - It displays as the bare integer.
/// - It provides `from_proto` / `to_proto` conversions to and from `u64`.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Debug,
            Default,
            Clone,
            Copy,
            Hash,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Serialize,
            Deserialize,
            sqlx::Type,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Builds an id from its `u64` protocol form.
            ///
            /// This is an `as` cast, so it keeps only the low 32 bits of
            /// `value`; it does not check the range.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Converts the id to its `u64` protocol form via an `as` cast.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1345
id_type!(UserId);
/// A user account record, decoded from a query row via `FromRow`
/// (field names must match the selected column names).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub email_address: Option<String>,
    pub admin: bool,
    // `None` until a code has been assigned — confirm against the invite-code queries.
    pub invite_code: Option<String>,
    // Invite allowance, e.g. as supplied to `create_users` / `set_invite_count`.
    pub invite_count: i32,
    pub connected_once: bool,
}
1357
id_type!(ProjectId);
/// A project record hosted by `host_user_id`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    pub host_user_id: UserId,
    // NOTE(review): presumably set once the host unregisters the project — confirm.
    pub unregistered: bool,
}
1365
/// Per-user activity totals, as returned by `get_top_users_activity_summary`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// One entry per project the user was active in.
    pub project_activity: Vec<ProjectActivitySummary>,
}
1372
/// One project's entry within a `UserActivitySummary`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ProjectActivitySummary {
    id: ProjectId,
    /// Total time the user was recorded as active in this project.
    duration: Duration,
    /// Largest number of collaborators recorded in the project at once
    /// (presumably counting the user themselves — confirm).
    max_collaborators: usize,
}
1379
/// A contiguous stretch of activity in one project, as returned by
/// `get_user_activity_timeline`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    project_id: ProjectId,
    /// Serialized as an ISO 8601 timestamp.
    #[serde(with = "time::serde::iso8601")]
    start: OffsetDateTime,
    /// Serialized as an ISO 8601 timestamp.
    #[serde(with = "time::serde::iso8601")]
    end: OffsetDateTime,
    /// Extension -> count (e.g. `"rs" -> 5`); presumably the project's
    /// worktree extension counts — confirm.
    extensions: HashMap<String, usize>,
}
1389
id_type!(OrgId);
/// A row of the `orgs` table, as created by `create_org`.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Unique-looking handle used by `find_org_by_slug` lookups.
    pub slug: String,
}
1397
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Owning org's id when `owner_is_user` is false (see `create_org_channel`);
    /// presumably a user id otherwise.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1406
id_type!(MessageId);
/// A row of the `channel_messages` table.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; `get_channel_messages` selects it `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Caller-supplied nonce; `create_channel_message` deduplicates on it.
    pub nonce: Uuid,
}
1417
/// One entry in a user's contact list, seen from that user's side
/// (as built by `get_contacts`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// An accepted contact. `should_notify` is only ever set for the user who
    /// originally sent the request.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A pending request this user sent to `user_id`.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request this user received from `user_id`.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1432
1433impl Contact {
1434    pub fn user_id(&self) -> UserId {
1435        match self {
1436            Contact::Accepted { user_id, .. } => *user_id,
1437            Contact::Outgoing { user_id } => *user_id,
1438            Contact::Incoming { user_id, .. } => *user_id,
1439        }
1440    }
1441}
1442
/// An incoming contact request from `requester_id`.
///
/// The derived `Ord` compares `requester_id` first, then `should_notify`,
/// so field order here is significant.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1448
/// Builds a `LIKE` pattern that matches any string containing the
/// alphanumeric characters of `string` in order, with anything in between.
///
/// Non-alphanumeric characters, including the wildcards `%` and `_`, are
/// dropped. For example, `"a-b"` becomes `"%a%b%"` and `""` becomes `"%"`.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    pattern.extend(
        string
            .chars()
            .filter(|ch| ch.is_alphanumeric())
            .flat_map(|ch| ['%', ch]),
    );
    pattern.push('%');
    pattern
}
1460
1461#[cfg(test)]
1462pub mod tests {
1463    use super::*;
1464    use anyhow::anyhow;
1465    use collections::BTreeMap;
1466    use gpui::executor::{Background, Deterministic};
1467    use lazy_static::lazy_static;
1468    use parking_lot::Mutex;
1469    use rand::prelude::*;
1470    use sqlx::{
1471        migrate::{MigrateDatabase, Migrator},
1472        Postgres,
1473    };
1474    use std::{path::Path, sync::Arc};
1475    use util::post_inc;
1476
1477    #[tokio::test(flavor = "multi_thread")]
1478    async fn test_get_users_by_ids() {
1479        for test_db in [
1480            TestDb::postgres().await,
1481            TestDb::fake(build_background_executor()),
1482        ] {
1483            let db = test_db.db();
1484
1485            let user = db.create_user("user", None, false).await.unwrap();
1486            let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1487            let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1488            let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1489
1490            assert_eq!(
1491                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1492                    .await
1493                    .unwrap(),
1494                vec![
1495                    User {
1496                        id: user,
1497                        github_login: "user".to_string(),
1498                        admin: false,
1499                        ..Default::default()
1500                    },
1501                    User {
1502                        id: friend1,
1503                        github_login: "friend-1".to_string(),
1504                        admin: false,
1505                        ..Default::default()
1506                    },
1507                    User {
1508                        id: friend2,
1509                        github_login: "friend-2".to_string(),
1510                        admin: false,
1511                        ..Default::default()
1512                    },
1513                    User {
1514                        id: friend3,
1515                        github_login: "friend-3".to_string(),
1516                        admin: false,
1517                        ..Default::default()
1518                    }
1519                ]
1520            );
1521        }
1522    }
1523
1524    #[tokio::test(flavor = "multi_thread")]
1525    async fn test_create_users() {
1526        let db = TestDb::postgres().await;
1527        let db = db.db();
1528
1529        // Create the first batch of users, ensuring invite counts are assigned
1530        // correctly and the respective invite codes are unique.
1531        let user_ids_batch_1 = db
1532            .create_users(vec![
1533                ("user1".to_string(), "hi@user1.com".to_string(), 5),
1534                ("user2".to_string(), "hi@user2.com".to_string(), 4),
1535                ("user3".to_string(), "hi@user3.com".to_string(), 3),
1536            ])
1537            .await
1538            .unwrap();
1539        assert_eq!(user_ids_batch_1.len(), 3);
1540
1541        let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
1542        assert_eq!(users.len(), 3);
1543        assert_eq!(users[0].github_login, "user1");
1544        assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
1545        assert_eq!(users[0].invite_count, 5);
1546        assert_eq!(users[1].github_login, "user2");
1547        assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
1548        assert_eq!(users[1].invite_count, 4);
1549        assert_eq!(users[2].github_login, "user3");
1550        assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
1551        assert_eq!(users[2].invite_count, 3);
1552
1553        let invite_code_1 = users[0].invite_code.clone().unwrap();
1554        let invite_code_2 = users[1].invite_code.clone().unwrap();
1555        let invite_code_3 = users[2].invite_code.clone().unwrap();
1556        assert_ne!(invite_code_1, invite_code_2);
1557        assert_ne!(invite_code_1, invite_code_3);
1558        assert_ne!(invite_code_2, invite_code_3);
1559
1560        // Create the second batch of users and include a user that is already in the database, ensuring
1561        // the invite count for the existing user is updated without changing their invite code.
1562        let user_ids_batch_2 = db
1563            .create_users(vec![
1564                ("user2".to_string(), "hi@user2.com".to_string(), 10),
1565                ("user4".to_string(), "hi@user4.com".to_string(), 2),
1566            ])
1567            .await
1568            .unwrap();
1569        assert_eq!(user_ids_batch_2.len(), 2);
1570        assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);
1571
1572        let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
1573        assert_eq!(users.len(), 2);
1574        assert_eq!(users[0].github_login, "user2");
1575        assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
1576        assert_eq!(users[0].invite_count, 10);
1577        assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
1578        assert_eq!(users[1].github_login, "user4");
1579        assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
1580        assert_eq!(users[1].invite_count, 2);
1581
1582        let invite_code_4 = users[1].invite_code.clone().unwrap();
1583        assert_ne!(invite_code_4, invite_code_1);
1584        assert_ne!(invite_code_4, invite_code_2);
1585        assert_ne!(invite_code_4, invite_code_3);
1586    }
1587
1588    #[tokio::test(flavor = "multi_thread")]
1589    async fn test_worktree_extensions() {
1590        let test_db = TestDb::postgres().await;
1591        let db = test_db.db();
1592
1593        let user = db.create_user("user_1", None, false).await.unwrap();
1594        let project = db.register_project(user).await.unwrap();
1595
1596        db.update_worktree_extensions(project, 100, Default::default())
1597            .await
1598            .unwrap();
1599        db.update_worktree_extensions(
1600            project,
1601            100,
1602            [("rs".to_string(), 5), ("md".to_string(), 3)]
1603                .into_iter()
1604                .collect(),
1605        )
1606        .await
1607        .unwrap();
1608        db.update_worktree_extensions(
1609            project,
1610            100,
1611            [("rs".to_string(), 6), ("md".to_string(), 5)]
1612                .into_iter()
1613                .collect(),
1614        )
1615        .await
1616        .unwrap();
1617        db.update_worktree_extensions(
1618            project,
1619            101,
1620            [("ts".to_string(), 2), ("md".to_string(), 1)]
1621                .into_iter()
1622                .collect(),
1623        )
1624        .await
1625        .unwrap();
1626
1627        assert_eq!(
1628            db.get_project_extensions(project).await.unwrap(),
1629            [
1630                (
1631                    100,
1632                    [("rs".into(), 6), ("md".into(), 5),]
1633                        .into_iter()
1634                        .collect::<HashMap<_, _>>()
1635                ),
1636                (
1637                    101,
1638                    [("ts".into(), 2), ("md".into(), 1),]
1639                        .into_iter()
1640                        .collect::<HashMap<_, _>>()
1641                )
1642            ]
1643            .into_iter()
1644            .collect()
1645        );
1646    }
1647
1648    #[tokio::test(flavor = "multi_thread")]
1649    async fn test_user_activity() {
1650        let test_db = TestDb::postgres().await;
1651        let db = test_db.db();
1652
1653        let user_1 = db.create_user("user_1", None, false).await.unwrap();
1654        let user_2 = db.create_user("user_2", None, false).await.unwrap();
1655        let user_3 = db.create_user("user_3", None, false).await.unwrap();
1656        let project_1 = db.register_project(user_1).await.unwrap();
1657        db.update_worktree_extensions(
1658            project_1,
1659            1,
1660            HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]),
1661        )
1662        .await
1663        .unwrap();
1664        let project_2 = db.register_project(user_2).await.unwrap();
1665        let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);
1666
1667        // User 2 opens a project
1668        let t1 = t0 + Duration::from_secs(10);
1669        db.record_user_activity(t0..t1, &[(user_2, project_2)])
1670            .await
1671            .unwrap();
1672
1673        let t2 = t1 + Duration::from_secs(10);
1674        db.record_user_activity(t1..t2, &[(user_2, project_2)])
1675            .await
1676            .unwrap();
1677
1678        // User 1 joins the project
1679        let t3 = t2 + Duration::from_secs(10);
1680        db.record_user_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
1681            .await
1682            .unwrap();
1683
1684        // User 1 opens another project
1685        let t4 = t3 + Duration::from_secs(10);
1686        db.record_user_activity(
1687            t3..t4,
1688            &[
1689                (user_2, project_2),
1690                (user_1, project_2),
1691                (user_1, project_1),
1692            ],
1693        )
1694        .await
1695        .unwrap();
1696
1697        // User 3 joins that project
1698        let t5 = t4 + Duration::from_secs(10);
1699        db.record_user_activity(
1700            t4..t5,
1701            &[
1702                (user_2, project_2),
1703                (user_1, project_2),
1704                (user_1, project_1),
1705                (user_3, project_1),
1706            ],
1707        )
1708        .await
1709        .unwrap();
1710
1711        // User 2 leaves
1712        let t6 = t5 + Duration::from_secs(5);
1713        db.record_user_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
1714            .await
1715            .unwrap();
1716
1717        let t7 = t6 + Duration::from_secs(60);
1718        let t8 = t7 + Duration::from_secs(10);
1719        db.record_user_activity(t7..t8, &[(user_1, project_1)])
1720            .await
1721            .unwrap();
1722
1723        assert_eq!(
1724            db.get_top_users_activity_summary(t0..t6, 10).await.unwrap(),
1725            &[
1726                UserActivitySummary {
1727                    id: user_1,
1728                    github_login: "user_1".to_string(),
1729                    project_activity: vec![
1730                        ProjectActivitySummary {
1731                            id: project_1,
1732                            duration: Duration::from_secs(25),
1733                            max_collaborators: 2
1734                        },
1735                        ProjectActivitySummary {
1736                            id: project_2,
1737                            duration: Duration::from_secs(30),
1738                            max_collaborators: 2
1739                        }
1740                    ]
1741                },
1742                UserActivitySummary {
1743                    id: user_2,
1744                    github_login: "user_2".to_string(),
1745                    project_activity: vec![ProjectActivitySummary {
1746                        id: project_2,
1747                        duration: Duration::from_secs(50),
1748                        max_collaborators: 2
1749                    }]
1750                },
1751                UserActivitySummary {
1752                    id: user_3,
1753                    github_login: "user_3".to_string(),
1754                    project_activity: vec![ProjectActivitySummary {
1755                        id: project_1,
1756                        duration: Duration::from_secs(15),
1757                        max_collaborators: 2
1758                    }]
1759                },
1760            ]
1761        );
1762
1763        assert_eq!(
1764            db.get_active_user_count(t0..t6, Duration::from_secs(56), false)
1765                .await
1766                .unwrap(),
1767            0
1768        );
1769        assert_eq!(
1770            db.get_active_user_count(t0..t6, Duration::from_secs(56), true)
1771                .await
1772                .unwrap(),
1773            0
1774        );
1775        assert_eq!(
1776            db.get_active_user_count(t0..t6, Duration::from_secs(54), false)
1777                .await
1778                .unwrap(),
1779            1
1780        );
1781        assert_eq!(
1782            db.get_active_user_count(t0..t6, Duration::from_secs(54), true)
1783                .await
1784                .unwrap(),
1785            1
1786        );
1787        assert_eq!(
1788            db.get_active_user_count(t0..t6, Duration::from_secs(30), false)
1789                .await
1790                .unwrap(),
1791            2
1792        );
1793        assert_eq!(
1794            db.get_active_user_count(t0..t6, Duration::from_secs(30), true)
1795                .await
1796                .unwrap(),
1797            2
1798        );
1799        assert_eq!(
1800            db.get_active_user_count(t0..t6, Duration::from_secs(10), false)
1801                .await
1802                .unwrap(),
1803            3
1804        );
1805        assert_eq!(
1806            db.get_active_user_count(t0..t6, Duration::from_secs(10), true)
1807                .await
1808                .unwrap(),
1809            3
1810        );
1811        assert_eq!(
1812            db.get_active_user_count(t0..t1, Duration::from_secs(5), false)
1813                .await
1814                .unwrap(),
1815            1
1816        );
1817        assert_eq!(
1818            db.get_active_user_count(t0..t1, Duration::from_secs(5), true)
1819                .await
1820                .unwrap(),
1821            0
1822        );
1823
1824        assert_eq!(
1825            db.get_user_activity_timeline(t3..t6, user_1).await.unwrap(),
1826            &[
1827                UserActivityPeriod {
1828                    project_id: project_1,
1829                    start: t3,
1830                    end: t6,
1831                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1832                },
1833                UserActivityPeriod {
1834                    project_id: project_2,
1835                    start: t3,
1836                    end: t5,
1837                    extensions: Default::default(),
1838                },
1839            ]
1840        );
1841        assert_eq!(
1842            db.get_user_activity_timeline(t0..t8, user_1).await.unwrap(),
1843            &[
1844                UserActivityPeriod {
1845                    project_id: project_2,
1846                    start: t2,
1847                    end: t5,
1848                    extensions: Default::default(),
1849                },
1850                UserActivityPeriod {
1851                    project_id: project_1,
1852                    start: t3,
1853                    end: t6,
1854                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1855                },
1856                UserActivityPeriod {
1857                    project_id: project_1,
1858                    start: t7,
1859                    end: t8,
1860                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1861                },
1862            ]
1863        );
1864    }
1865
1866    #[tokio::test(flavor = "multi_thread")]
1867    async fn test_recent_channel_messages() {
1868        for test_db in [
1869            TestDb::postgres().await,
1870            TestDb::fake(build_background_executor()),
1871        ] {
1872            let db = test_db.db();
1873            let user = db.create_user("user", None, false).await.unwrap();
1874            let org = db.create_org("org", "org").await.unwrap();
1875            let channel = db.create_org_channel(org, "channel").await.unwrap();
1876            for i in 0..10 {
1877                db.create_channel_message(
1878                    channel,
1879                    user,
1880                    &i.to_string(),
1881                    OffsetDateTime::now_utc(),
1882                    i,
1883                )
1884                .await
1885                .unwrap();
1886            }
1887
1888            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1889            assert_eq!(
1890                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1891                ["5", "6", "7", "8", "9"]
1892            );
1893
1894            let prev_messages = db
1895                .get_channel_messages(channel, 4, Some(messages[0].id))
1896                .await
1897                .unwrap();
1898            assert_eq!(
1899                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1900                ["1", "2", "3", "4"]
1901            );
1902        }
1903    }
1904
1905    #[tokio::test(flavor = "multi_thread")]
1906    async fn test_channel_message_nonces() {
1907        for test_db in [
1908            TestDb::postgres().await,
1909            TestDb::fake(build_background_executor()),
1910        ] {
1911            let db = test_db.db();
1912            let user = db.create_user("user", None, false).await.unwrap();
1913            let org = db.create_org("org", "org").await.unwrap();
1914            let channel = db.create_org_channel(org, "channel").await.unwrap();
1915
1916            let msg1_id = db
1917                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1918                .await
1919                .unwrap();
1920            let msg2_id = db
1921                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1922                .await
1923                .unwrap();
1924            let msg3_id = db
1925                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1926                .await
1927                .unwrap();
1928            let msg4_id = db
1929                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1930                .await
1931                .unwrap();
1932
1933            assert_ne!(msg1_id, msg2_id);
1934            assert_eq!(msg1_id, msg3_id);
1935            assert_eq!(msg2_id, msg4_id);
1936        }
1937    }
1938
1939    #[tokio::test(flavor = "multi_thread")]
1940    async fn test_create_access_tokens() {
1941        let test_db = TestDb::postgres().await;
1942        let db = test_db.db();
1943        let user = db.create_user("the-user", None, false).await.unwrap();
1944
1945        db.create_access_token_hash(user, "h1", 3).await.unwrap();
1946        db.create_access_token_hash(user, "h2", 3).await.unwrap();
1947        assert_eq!(
1948            db.get_access_token_hashes(user).await.unwrap(),
1949            &["h2".to_string(), "h1".to_string()]
1950        );
1951
1952        db.create_access_token_hash(user, "h3", 3).await.unwrap();
1953        assert_eq!(
1954            db.get_access_token_hashes(user).await.unwrap(),
1955            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1956        );
1957
1958        db.create_access_token_hash(user, "h4", 3).await.unwrap();
1959        assert_eq!(
1960            db.get_access_token_hashes(user).await.unwrap(),
1961            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1962        );
1963
1964        db.create_access_token_hash(user, "h5", 3).await.unwrap();
1965        assert_eq!(
1966            db.get_access_token_hashes(user).await.unwrap(),
1967            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1968        );
1969    }
1970
1971    #[test]
1972    fn test_fuzzy_like_string() {
1973        assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1974        assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1975        assert_eq!(fuzzy_like_string(" z  "), "%z%");
1976    }
1977
1978    #[tokio::test(flavor = "multi_thread")]
1979    async fn test_fuzzy_search_users() {
1980        let test_db = TestDb::postgres().await;
1981        let db = test_db.db();
1982        for github_login in [
1983            "California",
1984            "colorado",
1985            "oregon",
1986            "washington",
1987            "florida",
1988            "delaware",
1989            "rhode-island",
1990        ] {
1991            db.create_user(github_login, None, false).await.unwrap();
1992        }
1993
1994        assert_eq!(
1995            fuzzy_search_user_names(db, "clr").await,
1996            &["colorado", "California"]
1997        );
1998        assert_eq!(
1999            fuzzy_search_user_names(db, "ro").await,
2000            &["rhode-island", "colorado", "oregon"],
2001        );
2002
2003        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
2004            db.fuzzy_search_users(query, 10)
2005                .await
2006                .unwrap()
2007                .into_iter()
2008                .map(|user| user.github_login)
2009                .collect::<Vec<_>>()
2010        }
2011    }
2012
    /// Walks the contact-request lifecycle against both the Postgres and the fake
    /// backend: requesting, dismissing notifications, accepting, re-requesting,
    /// simultaneous mutual requests, and declining.
    ///
    /// Every `get_contacts` result includes the user themself as an `Accepted`
    /// contact with `should_notify: false`, and entries appear ordered by user id.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(build_background_executor()),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", None, false).await.unwrap();
            let user_2 = db.create_user("user2", None, false).await.unwrap();
            let user_3 = db.create_user("user3", None, false).await.unwrap();

            // User starts with no contacts (other than the self-entry every list contains)
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                vec![Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );

            // User requests a contact. Both users see the pending request.
            // The request is not yet a contact in either direction; only the
            // recipient is flagged for notification.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert!(!db.has_contact(user_1, user_2).await.unwrap());
            assert!(!db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Outgoing { user_id: user_2 }
                ],
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: true
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again.
            // The requester (user 1) cannot dismiss it on user 2's behalf, so that call errors.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact.
            // The original requester (user 1) is notified of the acceptance.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true
                    }
                ],
            );
            assert!(db.has_contact(user_1, user_2).await.unwrap());
            assert!(db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users cannot re-request existing contacts.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users can't dismiss notifications of them accepting other users' requests.
            // The failed dismissal leaves user 1's notification flag untouched.
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap_err();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true,
                    },
                ]
            );

            // Users can dismiss notifications of other users accepting their requests.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users send each other concurrent contact requests and
            // see that they are immediately accepted.
            // Neither side is flagged for notification in this case.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    },
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert!(!db.has_contact(user_2, user_3).await.unwrap());
            assert!(!db.has_contact(user_3, user_2).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    }
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );
        }
    }
2232
    /// Exercises invite codes end to end against Postgres only: creating a code by
    /// setting a positive invite count, redeeming it (which makes the redeemer an
    /// accepted contact of the inviter), exhausting it, re-enabling it, and rejecting
    /// redemption by an existing user.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", None, false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

        // Setting invite count to 0 when no code is assigned does not assign a new code
        db.set_invite_count(user1, 0).await.unwrap();
        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        // Each redemption decrements the inviter's remaining count, and the inviter
        // (not the redeemer) is flagged for notification.
        let user2 = db
            .redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db
            .redeem_invite_code(&invite_code, "user-3", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db
            .redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes.
        // The failed redemption does not consume one of the remaining uses.
        db.redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap_err();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
    }
2385
    /// A database backend for tests: either a freshly created Postgres database
    /// (`TestDb::postgres`) or an in-memory `FakeDb` (`TestDb::fake`).
    pub struct TestDb {
        /// The backend. It is `Some` until `Drop` takes it to run teardown.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the Postgres test database; empty for the fake backend.
        pub url: String,
    }
2390
2391    impl TestDb {
2392        pub async fn postgres() -> Self {
2393            lazy_static! {
2394                static ref LOCK: Mutex<()> = Mutex::new(());
2395            }
2396
2397            let _guard = LOCK.lock();
2398            let mut rng = StdRng::from_entropy();
2399            let name = format!("zed-test-{}", rng.gen::<u128>());
2400            let url = format!("postgres://postgres@localhost/{}", name);
2401            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2402            Postgres::create_database(&url)
2403                .await
2404                .expect("failed to create test db");
2405            let db = PostgresDb::new(&url, 5).await.unwrap();
2406            let migrator = Migrator::new(migrations_path).await.unwrap();
2407            migrator.run(&db.pool).await.unwrap();
2408            Self {
2409                db: Some(Arc::new(db)),
2410                url,
2411            }
2412        }
2413
2414        pub fn fake(background: Arc<Background>) -> Self {
2415            Self {
2416                db: Some(Arc::new(FakeDb::new(background))),
2417                url: Default::default(),
2418            }
2419        }
2420
2421        pub fn db(&self) -> &Arc<dyn Db> {
2422            self.db.as_ref().unwrap()
2423        }
2424    }
2425
    impl Drop for TestDb {
        /// Runs the backend's `teardown` with this database's URL. It blocks the
        /// current thread on the future because `Drop::drop` cannot be async.
        fn drop(&mut self) {
            // `take` leaves `None` behind; a `TestDb` whose `db` is already `None`
            // is not torn down.
            if let Some(db) = self.db.take() {
                futures::executor::block_on(db.teardown(&self.url));
            }
        }
    }
2433
    /// In-memory implementation of [`Db`] used by `TestDb::fake`. All state lives in
    /// `Mutex`-guarded collections, and operations await artificial delays produced
    /// by `background`.
    pub struct FakeDb {
        /// Source of the `simulate_random_delay` calls awaited by most operations.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        /// Count per `(project id, worktree id, extension)`, as written by
        /// `update_worktree_extensions`.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        // NOTE(review): the meaning of the `bool` value (admin flag?) is not visible here — confirm.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        // NOTE(review): the meaning of the `bool` value is not visible here — confirm.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        pub contacts: Mutex<Vec<FakeContact>>,
        // Next-id counters, one per table. `create_user` and `register_project`
        // advance theirs with `post_inc`.
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
2451
    /// One contact relationship stored in `FakeDb::contacts`.
    #[derive(Debug)]
    pub struct FakeContact {
        /// Presumably the user who sent the contact request.
        pub requester_id: UserId,
        /// Presumably the user the request was sent to.
        pub responder_id: UserId,
        /// Presumably whether the request has been accepted.
        pub accepted: bool,
        /// Presumably whether the other party still has a pending notification —
        /// TODO confirm against the Postgres schema.
        pub should_notify: bool,
    }
2459
2460    impl FakeDb {
2461        pub fn new(background: Arc<Background>) -> Self {
2462            Self {
2463                background,
2464                users: Default::default(),
2465                next_user_id: Mutex::new(0),
2466                projects: Default::default(),
2467                worktree_extensions: Default::default(),
2468                next_project_id: Mutex::new(1),
2469                orgs: Default::default(),
2470                next_org_id: Mutex::new(1),
2471                org_memberships: Default::default(),
2472                channels: Default::default(),
2473                next_channel_id: Mutex::new(1),
2474                channel_memberships: Default::default(),
2475                channel_messages: Default::default(),
2476                next_channel_message_id: Mutex::new(1),
2477                contacts: Default::default(),
2478            }
2479        }
2480    }
2481
2482    #[async_trait]
2483    impl Db for FakeDb {
2484        async fn create_user(
2485            &self,
2486            github_login: &str,
2487            email_address: Option<&str>,
2488            admin: bool,
2489        ) -> Result<UserId> {
2490            self.background.simulate_random_delay().await;
2491
2492            let mut users = self.users.lock();
2493            if let Some(user) = users
2494                .values()
2495                .find(|user| user.github_login == github_login)
2496            {
2497                Ok(user.id)
2498            } else {
2499                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2500                users.insert(
2501                    user_id,
2502                    User {
2503                        id: user_id,
2504                        github_login: github_login.to_string(),
2505                        email_address: email_address.map(str::to_string),
2506                        admin,
2507                        invite_code: None,
2508                        invite_count: 0,
2509                        connected_once: false,
2510                    },
2511                );
2512                Ok(user_id)
2513            }
2514        }
2515
        /// Not supported by the fake backend; panics via `unimplemented!()` if called.
        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2519
        /// Not supported by the fake backend; panics via `unimplemented!()` if called.
        async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
            unimplemented!()
        }
2523
        /// Not supported by the fake backend; panics via `unimplemented!()` if called.
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2527
2528        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2529            self.background.simulate_random_delay().await;
2530            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2531        }
2532
2533        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2534            self.background.simulate_random_delay().await;
2535            let users = self.users.lock();
2536            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2537        }
2538
        /// Not supported by the fake backend; panics via `unimplemented!()` if called.
        async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
            unimplemented!()
        }
2542
2543        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2544            self.background.simulate_random_delay().await;
2545            Ok(self
2546                .users
2547                .lock()
2548                .values()
2549                .find(|user| user.github_login == github_login)
2550                .cloned())
2551        }
2552
        /// Not supported by the fake backend; panics via `unimplemented!()` if called.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
2556
2557        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2558            self.background.simulate_random_delay().await;
2559            let mut users = self.users.lock();
2560            let mut user = users
2561                .get_mut(&id)
2562                .ok_or_else(|| anyhow!("user not found"))?;
2563            user.connected_once = connected_once;
2564            Ok(())
2565        }
2566
        /// Not supported by the fake backend; panics via `unimplemented!()` if called.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }
2570
2571        // invite codes
2572
        /// Not supported by the fake backend; panics via `unimplemented!()` if called.
        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }
2576
        /// Always reports that the user has no invite code, after a simulated delay.
        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            self.background.simulate_random_delay().await;
            Ok(None)
        }
2581
        /// Not supported by the fake backend; panics via `unimplemented!()` if called.
        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }
2585
        /// Not supported by the fake backend; panics via `unimplemented!()` if called.
        async fn redeem_invite_code(
            &self,
            _code: &str,
            _login: &str,
            _email_address: Option<&str>,
        ) -> Result<UserId> {
            unimplemented!()
        }
2594
2595        // projects
2596
2597        async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2598            self.background.simulate_random_delay().await;
2599            if !self.users.lock().contains_key(&host_user_id) {
2600                Err(anyhow!("no such user"))?;
2601            }
2602
2603            let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2604            self.projects.lock().insert(
2605                project_id,
2606                Project {
2607                    id: project_id,
2608                    host_user_id,
2609                    unregistered: false,
2610                },
2611            );
2612            Ok(project_id)
2613        }
2614
2615        async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2616            self.background.simulate_random_delay().await;
2617            self.projects
2618                .lock()
2619                .get_mut(&project_id)
2620                .ok_or_else(|| anyhow!("no such project"))?
2621                .unregistered = true;
2622            Ok(())
2623        }
2624
2625        async fn update_worktree_extensions(
2626            &self,
2627            project_id: ProjectId,
2628            worktree_id: u64,
2629            extensions: HashMap<String, u32>,
2630        ) -> Result<()> {
2631            self.background.simulate_random_delay().await;
2632            if !self.projects.lock().contains_key(&project_id) {
2633                Err(anyhow!("no such project"))?;
2634            }
2635
2636            for (extension, count) in extensions {
2637                self.worktree_extensions
2638                    .lock()
2639                    .insert((project_id, worktree_id, extension), count);
2640            }
2641
2642            Ok(())
2643        }
2644
        /// Not supported by the fake database: calling this panics via `unimplemented!()`.
        async fn get_project_extensions(
            &self,
            _project_id: ProjectId,
        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
            unimplemented!()
        }
2651
        /// Not supported by the fake database: calling this panics via `unimplemented!()`.
        async fn record_user_activity(
            &self,
            _time_period: Range<OffsetDateTime>,
            _active_projects: &[(UserId, ProjectId)],
        ) -> Result<()> {
            unimplemented!()
        }
2659
        /// Not supported by the fake database: calling this panics via `unimplemented!()`.
        async fn get_active_user_count(
            &self,
            _time_period: Range<OffsetDateTime>,
            _min_duration: Duration,
            _only_collaborative: bool,
        ) -> Result<usize> {
            unimplemented!()
        }
2668
        /// Not supported by the fake database: calling this panics via `unimplemented!()`.
        async fn get_top_users_activity_summary(
            &self,
            _time_period: Range<OffsetDateTime>,
            _limit: usize,
        ) -> Result<Vec<UserActivitySummary>> {
            unimplemented!()
        }
2676
        /// Not supported by the fake database: calling this panics via `unimplemented!()`.
        async fn get_user_activity_timeline(
            &self,
            _time_period: Range<OffsetDateTime>,
            _user_id: UserId,
        ) -> Result<Vec<UserActivityPeriod>> {
            unimplemented!()
        }
2684
2685        // contacts
2686
2687        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2688            self.background.simulate_random_delay().await;
2689            let mut contacts = vec![Contact::Accepted {
2690                user_id: id,
2691                should_notify: false,
2692            }];
2693
2694            for contact in self.contacts.lock().iter() {
2695                if contact.requester_id == id {
2696                    if contact.accepted {
2697                        contacts.push(Contact::Accepted {
2698                            user_id: contact.responder_id,
2699                            should_notify: contact.should_notify,
2700                        });
2701                    } else {
2702                        contacts.push(Contact::Outgoing {
2703                            user_id: contact.responder_id,
2704                        });
2705                    }
2706                } else if contact.responder_id == id {
2707                    if contact.accepted {
2708                        contacts.push(Contact::Accepted {
2709                            user_id: contact.requester_id,
2710                            should_notify: false,
2711                        });
2712                    } else {
2713                        contacts.push(Contact::Incoming {
2714                            user_id: contact.requester_id,
2715                            should_notify: contact.should_notify,
2716                        });
2717                    }
2718                }
2719            }
2720
2721            contacts.sort_unstable_by_key(|contact| contact.user_id());
2722            Ok(contacts)
2723        }
2724
2725        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2726            self.background.simulate_random_delay().await;
2727            Ok(self.contacts.lock().iter().any(|contact| {
2728                contact.accepted
2729                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2730                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2731            }))
2732        }
2733
2734        async fn send_contact_request(
2735            &self,
2736            requester_id: UserId,
2737            responder_id: UserId,
2738        ) -> Result<()> {
2739            self.background.simulate_random_delay().await;
2740            let mut contacts = self.contacts.lock();
2741            for contact in contacts.iter_mut() {
2742                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2743                    if contact.accepted {
2744                        Err(anyhow!("contact already exists"))?;
2745                    } else {
2746                        Err(anyhow!("contact already requested"))?;
2747                    }
2748                }
2749                if contact.responder_id == requester_id && contact.requester_id == responder_id {
2750                    if contact.accepted {
2751                        Err(anyhow!("contact already exists"))?;
2752                    } else {
2753                        contact.accepted = true;
2754                        contact.should_notify = false;
2755                        return Ok(());
2756                    }
2757                }
2758            }
2759            contacts.push(FakeContact {
2760                requester_id,
2761                responder_id,
2762                accepted: false,
2763                should_notify: true,
2764            });
2765            Ok(())
2766        }
2767
2768        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2769            self.background.simulate_random_delay().await;
2770            self.contacts.lock().retain(|contact| {
2771                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2772            });
2773            Ok(())
2774        }
2775
2776        async fn dismiss_contact_notification(
2777            &self,
2778            user_id: UserId,
2779            contact_user_id: UserId,
2780        ) -> Result<()> {
2781            self.background.simulate_random_delay().await;
2782            let mut contacts = self.contacts.lock();
2783            for contact in contacts.iter_mut() {
2784                if contact.requester_id == contact_user_id
2785                    && contact.responder_id == user_id
2786                    && !contact.accepted
2787                {
2788                    contact.should_notify = false;
2789                    return Ok(());
2790                }
2791                if contact.requester_id == user_id
2792                    && contact.responder_id == contact_user_id
2793                    && contact.accepted
2794                {
2795                    contact.should_notify = false;
2796                    return Ok(());
2797                }
2798            }
2799            Err(anyhow!("no such notification"))?
2800        }
2801
2802        async fn respond_to_contact_request(
2803            &self,
2804            responder_id: UserId,
2805            requester_id: UserId,
2806            accept: bool,
2807        ) -> Result<()> {
2808            self.background.simulate_random_delay().await;
2809            let mut contacts = self.contacts.lock();
2810            for (ix, contact) in contacts.iter_mut().enumerate() {
2811                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2812                    if contact.accepted {
2813                        Err(anyhow!("contact already confirmed"))?;
2814                    }
2815                    if accept {
2816                        contact.accepted = true;
2817                        contact.should_notify = true;
2818                    } else {
2819                        contacts.remove(ix);
2820                    }
2821                    return Ok(());
2822                }
2823            }
2824            Err(anyhow!("no such contact request"))?
2825        }
2826
        /// Not supported by the fake database: calling this panics via `unimplemented!()`.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }
2835
        /// Not supported by the fake database: calling this panics via `unimplemented!()`.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }
2839
        /// Not supported by the fake database: calling this panics via `unimplemented!()`.
        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }
2843
2844        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2845            self.background.simulate_random_delay().await;
2846            let mut orgs = self.orgs.lock();
2847            if orgs.values().any(|org| org.slug == slug) {
2848                Err(anyhow!("org already exists"))?
2849            } else {
2850                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2851                orgs.insert(
2852                    org_id,
2853                    Org {
2854                        id: org_id,
2855                        name: name.to_string(),
2856                        slug: slug.to_string(),
2857                    },
2858                );
2859                Ok(org_id)
2860            }
2861        }
2862
2863        async fn add_org_member(
2864            &self,
2865            org_id: OrgId,
2866            user_id: UserId,
2867            is_admin: bool,
2868        ) -> Result<()> {
2869            self.background.simulate_random_delay().await;
2870            if !self.orgs.lock().contains_key(&org_id) {
2871                Err(anyhow!("org does not exist"))?;
2872            }
2873            if !self.users.lock().contains_key(&user_id) {
2874                Err(anyhow!("user does not exist"))?;
2875            }
2876
2877            self.org_memberships
2878                .lock()
2879                .entry((org_id, user_id))
2880                .or_insert(is_admin);
2881            Ok(())
2882        }
2883
2884        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2885            self.background.simulate_random_delay().await;
2886            if !self.orgs.lock().contains_key(&org_id) {
2887                Err(anyhow!("org does not exist"))?;
2888            }
2889
2890            let mut channels = self.channels.lock();
2891            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2892            channels.insert(
2893                channel_id,
2894                Channel {
2895                    id: channel_id,
2896                    name: name.to_string(),
2897                    owner_id: org_id.0,
2898                    owner_is_user: false,
2899                },
2900            );
2901            Ok(channel_id)
2902        }
2903
2904        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2905            self.background.simulate_random_delay().await;
2906            Ok(self
2907                .channels
2908                .lock()
2909                .values()
2910                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2911                .cloned()
2912                .collect())
2913        }
2914
2915        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2916            self.background.simulate_random_delay().await;
2917            let channels = self.channels.lock();
2918            let memberships = self.channel_memberships.lock();
2919            Ok(channels
2920                .values()
2921                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2922                .cloned()
2923                .collect())
2924        }
2925
2926        async fn can_user_access_channel(
2927            &self,
2928            user_id: UserId,
2929            channel_id: ChannelId,
2930        ) -> Result<bool> {
2931            self.background.simulate_random_delay().await;
2932            Ok(self
2933                .channel_memberships
2934                .lock()
2935                .contains_key(&(channel_id, user_id)))
2936        }
2937
2938        async fn add_channel_member(
2939            &self,
2940            channel_id: ChannelId,
2941            user_id: UserId,
2942            is_admin: bool,
2943        ) -> Result<()> {
2944            self.background.simulate_random_delay().await;
2945            if !self.channels.lock().contains_key(&channel_id) {
2946                Err(anyhow!("channel does not exist"))?;
2947            }
2948            if !self.users.lock().contains_key(&user_id) {
2949                Err(anyhow!("user does not exist"))?;
2950            }
2951
2952            self.channel_memberships
2953                .lock()
2954                .entry((channel_id, user_id))
2955                .or_insert(is_admin);
2956            Ok(())
2957        }
2958
2959        async fn create_channel_message(
2960            &self,
2961            channel_id: ChannelId,
2962            sender_id: UserId,
2963            body: &str,
2964            timestamp: OffsetDateTime,
2965            nonce: u128,
2966        ) -> Result<MessageId> {
2967            self.background.simulate_random_delay().await;
2968            if !self.channels.lock().contains_key(&channel_id) {
2969                Err(anyhow!("channel does not exist"))?;
2970            }
2971            if !self.users.lock().contains_key(&sender_id) {
2972                Err(anyhow!("user does not exist"))?;
2973            }
2974
2975            let mut messages = self.channel_messages.lock();
2976            if let Some(message) = messages
2977                .values()
2978                .find(|message| message.nonce.as_u128() == nonce)
2979            {
2980                Ok(message.id)
2981            } else {
2982                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2983                messages.insert(
2984                    message_id,
2985                    ChannelMessage {
2986                        id: message_id,
2987                        channel_id,
2988                        sender_id,
2989                        body: body.to_string(),
2990                        sent_at: timestamp,
2991                        nonce: Uuid::from_u128(nonce),
2992                    },
2993                );
2994                Ok(message_id)
2995            }
2996        }
2997
2998        async fn get_channel_messages(
2999            &self,
3000            channel_id: ChannelId,
3001            count: usize,
3002            before_id: Option<MessageId>,
3003        ) -> Result<Vec<ChannelMessage>> {
3004            self.background.simulate_random_delay().await;
3005            let mut messages = self
3006                .channel_messages
3007                .lock()
3008                .values()
3009                .rev()
3010                .filter(|message| {
3011                    message.channel_id == channel_id
3012                        && message.id < before_id.unwrap_or(MessageId::MAX)
3013                })
3014                .take(count)
3015                .cloned()
3016                .collect::<Vec<_>>();
3017            messages.sort_unstable_by_key(|message| message.id);
3018            Ok(messages)
3019        }
3020
        /// No-op: all of the fake's state lives in in-process maps, so there is nothing
        /// external to clean up.
        async fn teardown(&self, _: &str) {}
3022
        /// Always returns `Some(self)`, giving test code direct access to the fake.
        /// Only compiled under `cfg(test)`.
        #[cfg(test)]
        fn as_fake(&self) -> Option<&FakeDb> {
            Some(self)
        }
3027    }
3028
3029    fn build_background_executor() -> Arc<Background> {
3030        Deterministic::new(0).build_background()
3031    }
3032}