db.rs

   1use std::{cmp, ops::Range, time::Duration};
   2
   3use crate::{Error, Result};
   4use anyhow::{anyhow, Context};
   5use async_trait::async_trait;
   6use axum::http::StatusCode;
   7use collections::HashMap;
   8use futures::StreamExt;
   9use serde::{Deserialize, Serialize};
  10pub use sqlx::postgres::PgPoolOptions as DbOptions;
  11use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
  12use time::{OffsetDateTime, PrimitiveDateTime};
  13
  14#[async_trait]
  15pub trait Db: Send + Sync {
  16    async fn create_user(
  17        &self,
  18        github_login: &str,
  19        email_address: Option<&str>,
  20        admin: bool,
  21    ) -> Result<UserId>;
  22    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
  23    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
  24    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
  25    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
  26    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
  27    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
  28    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
  29    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
  30    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
  31    async fn destroy_user(&self, id: UserId) -> Result<()>;
  32
  33    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
  34    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
  35    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
  36    async fn redeem_invite_code(
  37        &self,
  38        code: &str,
  39        login: &str,
  40        email_address: Option<&str>,
  41    ) -> Result<UserId>;
  42
  43    /// Registers a new project for the given user.
  44    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
  45
  46    /// Unregisters a project for the given project id.
  47    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
  48
  49    /// Update file counts by extension for the given project and worktree.
  50    async fn update_worktree_extensions(
  51        &self,
  52        project_id: ProjectId,
  53        worktree_id: u64,
  54        extensions: HashMap<String, u32>,
  55    ) -> Result<()>;
  56
  57    /// Get the file counts on the given project keyed by their worktree and extension.
  58    async fn get_project_extensions(
  59        &self,
  60        project_id: ProjectId,
  61    ) -> Result<HashMap<u64, HashMap<String, usize>>>;
  62
  63    /// Record which users have been active in which projects during
  64    /// a given period of time.
  65    async fn record_user_activity(
  66        &self,
  67        time_period: Range<OffsetDateTime>,
  68        active_projects: &[(UserId, ProjectId)],
  69    ) -> Result<()>;
  70
  71    /// Get the number of users who have been active in the given
  72    /// time period for at least the given time duration.
  73    async fn get_active_user_count(
  74        &self,
  75        time_period: Range<OffsetDateTime>,
  76        min_duration: Duration,
  77        only_collaborative: bool,
  78    ) -> Result<usize>;
  79
  80    /// Get the users that have been most active during the given time period,
  81    /// along with the amount of time they have been active in each project.
  82    async fn get_top_users_activity_summary(
  83        &self,
  84        time_period: Range<OffsetDateTime>,
  85        max_user_count: usize,
  86    ) -> Result<Vec<UserActivitySummary>>;
  87
  88    /// Get the project activity for the given user and time period.
  89    async fn get_user_activity_timeline(
  90        &self,
  91        time_period: Range<OffsetDateTime>,
  92        user_id: UserId,
  93    ) -> Result<Vec<UserActivityPeriod>>;
  94
  95    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
  96    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
  97    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  98    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  99    async fn dismiss_contact_notification(
 100        &self,
 101        responder_id: UserId,
 102        requester_id: UserId,
 103    ) -> Result<()>;
 104    async fn respond_to_contact_request(
 105        &self,
 106        responder_id: UserId,
 107        requester_id: UserId,
 108        accept: bool,
 109    ) -> Result<()>;
 110
 111    async fn create_access_token_hash(
 112        &self,
 113        user_id: UserId,
 114        access_token_hash: &str,
 115        max_access_token_count: usize,
 116    ) -> Result<()>;
 117    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
 118    #[cfg(any(test, feature = "seed-support"))]
 119
 120    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
 121    #[cfg(any(test, feature = "seed-support"))]
 122    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
 123    #[cfg(any(test, feature = "seed-support"))]
 124    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
 125    #[cfg(any(test, feature = "seed-support"))]
 126    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
 127    #[cfg(any(test, feature = "seed-support"))]
 128
 129    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
 130    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
 131    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
 132        -> Result<bool>;
 133    #[cfg(any(test, feature = "seed-support"))]
 134    async fn add_channel_member(
 135        &self,
 136        channel_id: ChannelId,
 137        user_id: UserId,
 138        is_admin: bool,
 139    ) -> Result<()>;
 140    async fn create_channel_message(
 141        &self,
 142        channel_id: ChannelId,
 143        sender_id: UserId,
 144        body: &str,
 145        timestamp: OffsetDateTime,
 146        nonce: u128,
 147    ) -> Result<MessageId>;
 148    async fn get_channel_messages(
 149        &self,
 150        channel_id: ChannelId,
 151        count: usize,
 152        before_id: Option<MessageId>,
 153    ) -> Result<Vec<ChannelMessage>>;
 154    #[cfg(test)]
 155    async fn teardown(&self, url: &str);
 156    #[cfg(test)]
 157    fn as_fake(&self) -> Option<&tests::FakeDb>;
 158}
 159
 160pub struct PostgresDb {
 161    pool: sqlx::PgPool,
 162}
 163
 164impl PostgresDb {
 165    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 166        let pool = DbOptions::new()
 167            .max_connections(max_connections)
 168            .connect(url)
 169            .await
 170            .context("failed to connect to postgres database")?;
 171        Ok(Self { pool })
 172    }
 173}
 174
 175#[async_trait]
 176impl Db for PostgresDb {
 177    // users
 178
 179    async fn create_user(
 180        &self,
 181        github_login: &str,
 182        email_address: Option<&str>,
 183        admin: bool,
 184    ) -> Result<UserId> {
 185        let query = "
 186            INSERT INTO users (github_login, email_address, admin)
 187            VALUES ($1, $2, $3)
 188            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 189            RETURNING id
 190        ";
 191        Ok(sqlx::query_scalar(query)
 192            .bind(github_login)
 193            .bind(email_address)
 194            .bind(admin)
 195            .fetch_one(&self.pool)
 196            .await
 197            .map(UserId)?)
 198    }
 199
 200    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 201        let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 202        Ok(sqlx::query_as(query)
 203            .bind(limit as i32)
 204            .bind((page * limit) as i32)
 205            .fetch_all(&self.pool)
 206            .await?)
 207    }
 208
 209    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
 210        let mut query = QueryBuilder::new(
 211            "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
 212        );
 213        query.push_values(
 214            users,
 215            |mut query, (github_login, email_address, invite_count)| {
 216                query
 217                    .push_bind(github_login)
 218                    .push_bind(email_address)
 219                    .push_bind(false)
 220                    .push_bind(random_invite_code())
 221                    .push_bind(invite_count as i32);
 222            },
 223        );
 224        query.push(
 225            "
 226            ON CONFLICT (github_login) DO UPDATE SET
 227                github_login = excluded.github_login,
 228                invite_count = excluded.invite_count,
 229                invite_code = CASE WHEN users.invite_code IS NULL
 230                                   THEN excluded.invite_code
 231                                   ELSE users.invite_code
 232                              END
 233            RETURNING id
 234            ",
 235        );
 236
 237        let rows = query.build().fetch_all(&self.pool).await?;
 238        Ok(rows
 239            .into_iter()
 240            .filter_map(|row| row.try_get::<UserId, _>(0).ok())
 241            .collect())
 242    }
 243
 244    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 245        let like_string = fuzzy_like_string(name_query);
 246        let query = "
 247            SELECT users.*
 248            FROM users
 249            WHERE github_login ILIKE $1
 250            ORDER BY github_login <-> $2
 251            LIMIT $3
 252        ";
 253        Ok(sqlx::query_as(query)
 254            .bind(like_string)
 255            .bind(name_query)
 256            .bind(limit as i32)
 257            .fetch_all(&self.pool)
 258            .await?)
 259    }
 260
 261    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 262        let users = self.get_users_by_ids(vec![id]).await?;
 263        Ok(users.into_iter().next())
 264    }
 265
 266    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 267        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 268        let query = "
 269            SELECT users.*
 270            FROM users
 271            WHERE users.id = ANY ($1)
 272        ";
 273        Ok(sqlx::query_as(query)
 274            .bind(&ids)
 275            .fetch_all(&self.pool)
 276            .await?)
 277    }
 278
 279    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
 280        let query = format!(
 281            "
 282            SELECT users.*
 283            FROM users
 284            WHERE invite_count = 0
 285            AND inviter_id IS{} NULL
 286            ",
 287            if invited_by_another_user { " NOT" } else { "" }
 288        );
 289
 290        Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 291    }
 292
 293    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 294        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 295        Ok(sqlx::query_as(query)
 296            .bind(github_login)
 297            .fetch_optional(&self.pool)
 298            .await?)
 299    }
 300
 301    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 302        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 303        Ok(sqlx::query(query)
 304            .bind(is_admin)
 305            .bind(id.0)
 306            .execute(&self.pool)
 307            .await
 308            .map(drop)?)
 309    }
 310
 311    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 312        let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 313        Ok(sqlx::query(query)
 314            .bind(connected_once)
 315            .bind(id.0)
 316            .execute(&self.pool)
 317            .await
 318            .map(drop)?)
 319    }
 320
 321    async fn destroy_user(&self, id: UserId) -> Result<()> {
 322        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 323        sqlx::query(query)
 324            .bind(id.0)
 325            .execute(&self.pool)
 326            .await
 327            .map(drop)?;
 328        let query = "DELETE FROM users WHERE id = $1;";
 329        Ok(sqlx::query(query)
 330            .bind(id.0)
 331            .execute(&self.pool)
 332            .await
 333            .map(drop)?)
 334    }
 335
 336    // invite codes
 337
 338    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
 339        let mut tx = self.pool.begin().await?;
 340        if count > 0 {
 341            sqlx::query(
 342                "
 343                UPDATE users
 344                SET invite_code = $1
 345                WHERE id = $2 AND invite_code IS NULL
 346            ",
 347            )
 348            .bind(random_invite_code())
 349            .bind(id)
 350            .execute(&mut tx)
 351            .await?;
 352        }
 353
 354        sqlx::query(
 355            "
 356            UPDATE users
 357            SET invite_count = $1
 358            WHERE id = $2
 359            ",
 360        )
 361        .bind(count as i32)
 362        .bind(id)
 363        .execute(&mut tx)
 364        .await?;
 365        tx.commit().await?;
 366        Ok(())
 367    }
 368
 369    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 370        let result: Option<(String, i32)> = sqlx::query_as(
 371            "
 372                SELECT invite_code, invite_count
 373                FROM users
 374                WHERE id = $1 AND invite_code IS NOT NULL 
 375            ",
 376        )
 377        .bind(id)
 378        .fetch_optional(&self.pool)
 379        .await?;
 380        if let Some((code, count)) = result {
 381            Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 382        } else {
 383            Ok(None)
 384        }
 385    }
 386
 387    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 388        sqlx::query_as(
 389            "
 390                SELECT *
 391                FROM users
 392                WHERE invite_code = $1
 393            ",
 394        )
 395        .bind(code)
 396        .fetch_optional(&self.pool)
 397        .await?
 398        .ok_or_else(|| {
 399            Error::Http(
 400                StatusCode::NOT_FOUND,
 401                "that invite code does not exist".to_string(),
 402            )
 403        })
 404    }
 405
 406    async fn redeem_invite_code(
 407        &self,
 408        code: &str,
 409        login: &str,
 410        email_address: Option<&str>,
 411    ) -> Result<UserId> {
 412        let mut tx = self.pool.begin().await?;
 413
 414        let inviter_id: Option<UserId> = sqlx::query_scalar(
 415            "
 416                UPDATE users
 417                SET invite_count = invite_count - 1
 418                WHERE
 419                    invite_code = $1 AND
 420                    invite_count > 0
 421                RETURNING id
 422            ",
 423        )
 424        .bind(code)
 425        .fetch_optional(&mut tx)
 426        .await?;
 427
 428        let inviter_id = match inviter_id {
 429            Some(inviter_id) => inviter_id,
 430            None => {
 431                if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
 432                    .bind(code)
 433                    .fetch_optional(&mut tx)
 434                    .await?
 435                    .is_some()
 436                {
 437                    Err(Error::Http(
 438                        StatusCode::UNAUTHORIZED,
 439                        "no invites remaining".to_string(),
 440                    ))?
 441                } else {
 442                    Err(Error::Http(
 443                        StatusCode::NOT_FOUND,
 444                        "invite code not found".to_string(),
 445                    ))?
 446                }
 447            }
 448        };
 449
 450        let invitee_id = sqlx::query_scalar(
 451            "
 452                INSERT INTO users
 453                    (github_login, email_address, admin, inviter_id, invite_code, invite_count)
 454                VALUES
 455                    ($1, $2, 'f', $3, $4, $5)
 456                RETURNING id
 457            ",
 458        )
 459        .bind(login)
 460        .bind(email_address)
 461        .bind(inviter_id)
 462        .bind(random_invite_code())
 463        .bind(5)
 464        .fetch_one(&mut tx)
 465        .await
 466        .map(UserId)?;
 467
 468        sqlx::query(
 469            "
 470                INSERT INTO contacts
 471                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 472                VALUES
 473                    ($1, $2, 't', 't', 't')
 474            ",
 475        )
 476        .bind(inviter_id)
 477        .bind(invitee_id)
 478        .execute(&mut tx)
 479        .await?;
 480
 481        tx.commit().await?;
 482        Ok(invitee_id)
 483    }
 484
 485    // projects
 486
 487    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
 488        Ok(sqlx::query_scalar(
 489            "
 490            INSERT INTO projects(host_user_id)
 491            VALUES ($1)
 492            RETURNING id
 493            ",
 494        )
 495        .bind(host_user_id)
 496        .fetch_one(&self.pool)
 497        .await
 498        .map(ProjectId)?)
 499    }
 500
 501    async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
 502        sqlx::query(
 503            "
 504            UPDATE projects
 505            SET unregistered = 't'
 506            WHERE id = $1
 507            ",
 508        )
 509        .bind(project_id)
 510        .execute(&self.pool)
 511        .await?;
 512        Ok(())
 513    }
 514
 515    async fn update_worktree_extensions(
 516        &self,
 517        project_id: ProjectId,
 518        worktree_id: u64,
 519        extensions: HashMap<String, u32>,
 520    ) -> Result<()> {
 521        if extensions.is_empty() {
 522            return Ok(());
 523        }
 524
 525        let mut query = QueryBuilder::new(
 526            "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
 527        );
 528        query.push_values(extensions, |mut query, (extension, count)| {
 529            query
 530                .push_bind(project_id)
 531                .push_bind(worktree_id as i32)
 532                .push_bind(extension)
 533                .push_bind(count as i32);
 534        });
 535        query.push(
 536            "
 537            ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
 538            count = excluded.count
 539            ",
 540        );
 541        query.build().execute(&self.pool).await?;
 542
 543        Ok(())
 544    }
 545
 546    async fn get_project_extensions(
 547        &self,
 548        project_id: ProjectId,
 549    ) -> Result<HashMap<u64, HashMap<String, usize>>> {
 550        #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
 551        struct WorktreeExtension {
 552            worktree_id: i32,
 553            extension: String,
 554            count: i32,
 555        }
 556
 557        let query = "
 558            SELECT worktree_id, extension, count
 559            FROM worktree_extensions
 560            WHERE project_id = $1
 561        ";
 562        let counts = sqlx::query_as::<_, WorktreeExtension>(query)
 563            .bind(&project_id)
 564            .fetch_all(&self.pool)
 565            .await?;
 566
 567        let mut extension_counts = HashMap::default();
 568        for count in counts {
 569            extension_counts
 570                .entry(count.worktree_id as u64)
 571                .or_insert_with(HashMap::default)
 572                .insert(count.extension, count.count as usize);
 573        }
 574        Ok(extension_counts)
 575    }
 576
 577    async fn record_user_activity(
 578        &self,
 579        time_period: Range<OffsetDateTime>,
 580        projects: &[(UserId, ProjectId)],
 581    ) -> Result<()> {
 582        let query = "
 583            INSERT INTO project_activity_periods
 584            (ended_at, duration_millis, user_id, project_id)
 585            VALUES
 586            ($1, $2, $3, $4);
 587        ";
 588
 589        let mut tx = self.pool.begin().await?;
 590        let duration_millis =
 591            ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
 592        for (user_id, project_id) in projects {
 593            sqlx::query(query)
 594                .bind(time_period.end)
 595                .bind(duration_millis)
 596                .bind(user_id)
 597                .bind(project_id)
 598                .execute(&mut tx)
 599                .await?;
 600        }
 601        tx.commit().await?;
 602
 603        Ok(())
 604    }
 605
 606    async fn get_active_user_count(
 607        &self,
 608        time_period: Range<OffsetDateTime>,
 609        min_duration: Duration,
 610        only_collaborative: bool,
 611    ) -> Result<usize> {
 612        let mut with_clause = String::new();
 613        with_clause.push_str("WITH\n");
 614        with_clause.push_str(
 615            "
 616            project_durations AS (
 617                SELECT user_id, project_id, SUM(duration_millis) AS project_duration
 618                FROM project_activity_periods
 619                WHERE $1 < ended_at AND ended_at <= $2
 620                GROUP BY user_id, project_id
 621            ),
 622            ",
 623        );
 624        with_clause.push_str(
 625            "
 626            project_collaborators as (
 627                SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
 628                FROM project_durations
 629                GROUP BY project_id
 630            ),
 631            ",
 632        );
 633
 634        if only_collaborative {
 635            with_clause.push_str(
 636                "
 637                user_durations AS (
 638                    SELECT user_id, SUM(project_duration) as total_duration
 639                    FROM project_durations, project_collaborators
 640                    WHERE
 641                        project_durations.project_id = project_collaborators.project_id AND
 642                        max_collaborators > 1
 643                    GROUP BY user_id
 644                    ORDER BY total_duration DESC
 645                    LIMIT $3
 646                )
 647                ",
 648            );
 649        } else {
 650            with_clause.push_str(
 651                "
 652                user_durations AS (
 653                    SELECT user_id, SUM(project_duration) as total_duration
 654                    FROM project_durations
 655                    GROUP BY user_id
 656                    ORDER BY total_duration DESC
 657                    LIMIT $3
 658                )
 659                ",
 660            );
 661        }
 662
 663        let query = format!(
 664            "
 665            {with_clause}
 666            SELECT count(user_durations.user_id)
 667            FROM user_durations
 668            WHERE user_durations.total_duration >= $3
 669            "
 670        );
 671
 672        let count: i64 = sqlx::query_scalar(&query)
 673            .bind(time_period.start)
 674            .bind(time_period.end)
 675            .bind(min_duration.as_millis() as i64)
 676            .fetch_one(&self.pool)
 677            .await?;
 678        Ok(count as usize)
 679    }
 680
 681    async fn get_top_users_activity_summary(
 682        &self,
 683        time_period: Range<OffsetDateTime>,
 684        max_user_count: usize,
 685    ) -> Result<Vec<UserActivitySummary>> {
 686        let query = "
 687            WITH
 688                project_durations AS (
 689                    SELECT user_id, project_id, SUM(duration_millis) AS project_duration
 690                    FROM project_activity_periods
 691                    WHERE $1 < ended_at AND ended_at <= $2
 692                    GROUP BY user_id, project_id
 693                ),
 694                user_durations AS (
 695                    SELECT user_id, SUM(project_duration) as total_duration
 696                    FROM project_durations
 697                    GROUP BY user_id
 698                    ORDER BY total_duration DESC
 699                    LIMIT $3
 700                ),
 701                project_collaborators as (
 702                    SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
 703                    FROM project_durations
 704                    GROUP BY project_id
 705                )
 706            SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
 707            FROM user_durations, project_durations, project_collaborators, users
 708            WHERE
 709                user_durations.user_id = project_durations.user_id AND
 710                user_durations.user_id = users.id AND
 711                project_durations.project_id = project_collaborators.project_id
 712            ORDER BY total_duration DESC, user_id ASC, project_id ASC
 713        ";
 714
 715        let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
 716            .bind(time_period.start)
 717            .bind(time_period.end)
 718            .bind(max_user_count as i32)
 719            .fetch(&self.pool);
 720
 721        let mut result = Vec::<UserActivitySummary>::new();
 722        while let Some(row) = rows.next().await {
 723            let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?;
 724            let project_id = project_id;
 725            let duration = Duration::from_millis(duration_millis as u64);
 726            let project_activity = ProjectActivitySummary {
 727                id: project_id,
 728                duration,
 729                max_collaborators: project_collaborators as usize,
 730            };
 731            if let Some(last_summary) = result.last_mut() {
 732                if last_summary.id == user_id {
 733                    last_summary.project_activity.push(project_activity);
 734                    continue;
 735                }
 736            }
 737            result.push(UserActivitySummary {
 738                id: user_id,
 739                project_activity: vec![project_activity],
 740                github_login,
 741            });
 742        }
 743
 744        Ok(result)
 745    }
 746
 747    async fn get_user_activity_timeline(
 748        &self,
 749        time_period: Range<OffsetDateTime>,
 750        user_id: UserId,
 751    ) -> Result<Vec<UserActivityPeriod>> {
 752        const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
 753
 754        let query = "
 755            SELECT
 756                project_activity_periods.ended_at,
 757                project_activity_periods.duration_millis,
 758                project_activity_periods.project_id,
 759                worktree_extensions.extension,
 760                worktree_extensions.count
 761            FROM project_activity_periods
 762            LEFT OUTER JOIN
 763                worktree_extensions
 764            ON
 765                project_activity_periods.project_id = worktree_extensions.project_id
 766            WHERE
 767                project_activity_periods.user_id = $1 AND
 768                $2 < project_activity_periods.ended_at AND
 769                project_activity_periods.ended_at <= $3
 770            ORDER BY project_activity_periods.id ASC
 771        ";
 772
 773        let mut rows = sqlx::query_as::<
 774            _,
 775            (
 776                PrimitiveDateTime,
 777                i32,
 778                ProjectId,
 779                Option<String>,
 780                Option<i32>,
 781            ),
 782        >(query)
 783        .bind(user_id)
 784        .bind(time_period.start)
 785        .bind(time_period.end)
 786        .fetch(&self.pool);
 787
 788        let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
 789        while let Some(row) = rows.next().await {
 790            let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
 791            let ended_at = ended_at.assume_utc();
 792            let duration = Duration::from_millis(duration_millis as u64);
 793            let started_at = ended_at - duration;
 794            let project_time_periods = time_periods.entry(project_id).or_default();
 795
 796            if let Some(prev_duration) = project_time_periods.last_mut() {
 797                if started_at <= prev_duration.end + COALESCE_THRESHOLD
 798                    && ended_at >= prev_duration.start
 799                {
 800                    prev_duration.end = cmp::max(prev_duration.end, ended_at);
 801                } else {
 802                    project_time_periods.push(UserActivityPeriod {
 803                        project_id,
 804                        start: started_at,
 805                        end: ended_at,
 806                        extensions: Default::default(),
 807                    });
 808                }
 809            } else {
 810                project_time_periods.push(UserActivityPeriod {
 811                    project_id,
 812                    start: started_at,
 813                    end: ended_at,
 814                    extensions: Default::default(),
 815                });
 816            }
 817
 818            if let Some((extension, extension_count)) = extension.zip(extension_count) {
 819                project_time_periods
 820                    .last_mut()
 821                    .unwrap()
 822                    .extensions
 823                    .insert(extension, extension_count as usize);
 824            }
 825        }
 826
 827        let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
 828        durations.sort_unstable_by_key(|duration| duration.start);
 829        Ok(durations)
 830    }
 831
 832    // contacts
 833
 834    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 835        let query = "
 836            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 837            FROM contacts
 838            WHERE user_id_a = $1 OR user_id_b = $1;
 839        ";
 840
 841        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 842            .bind(user_id)
 843            .fetch(&self.pool);
 844
 845        let mut contacts = vec![Contact::Accepted {
 846            user_id,
 847            should_notify: false,
 848        }];
 849        while let Some(row) = rows.next().await {
 850            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 851
 852            if user_id_a == user_id {
 853                if accepted {
 854                    contacts.push(Contact::Accepted {
 855                        user_id: user_id_b,
 856                        should_notify: should_notify && a_to_b,
 857                    });
 858                } else if a_to_b {
 859                    contacts.push(Contact::Outgoing { user_id: user_id_b })
 860                } else {
 861                    contacts.push(Contact::Incoming {
 862                        user_id: user_id_b,
 863                        should_notify,
 864                    });
 865                }
 866            } else if accepted {
 867                contacts.push(Contact::Accepted {
 868                    user_id: user_id_a,
 869                    should_notify: should_notify && !a_to_b,
 870                });
 871            } else if a_to_b {
 872                contacts.push(Contact::Incoming {
 873                    user_id: user_id_a,
 874                    should_notify,
 875                });
 876            } else {
 877                contacts.push(Contact::Outgoing { user_id: user_id_a });
 878            }
 879        }
 880
 881        contacts.sort_unstable_by_key(|contact| contact.user_id());
 882
 883        Ok(contacts)
 884    }
 885
 886    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 887        let (id_a, id_b) = if user_id_1 < user_id_2 {
 888            (user_id_1, user_id_2)
 889        } else {
 890            (user_id_2, user_id_1)
 891        };
 892
 893        let query = "
 894            SELECT 1 FROM contacts
 895            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
 896            LIMIT 1
 897        ";
 898        Ok(sqlx::query_scalar::<_, i32>(query)
 899            .bind(id_a.0)
 900            .bind(id_b.0)
 901            .fetch_optional(&self.pool)
 902            .await?
 903            .is_some())
 904    }
 905
 906    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 907        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 908            (sender_id, receiver_id, true)
 909        } else {
 910            (receiver_id, sender_id, false)
 911        };
 912        let query = "
 913            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 914            VALUES ($1, $2, $3, 'f', 't')
 915            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 916            SET
 917                accepted = 't',
 918                should_notify = 'f'
 919            WHERE
 920                NOT contacts.accepted AND
 921                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 922                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 923        ";
 924        let result = sqlx::query(query)
 925            .bind(id_a.0)
 926            .bind(id_b.0)
 927            .bind(a_to_b)
 928            .execute(&self.pool)
 929            .await?;
 930
 931        if result.rows_affected() == 1 {
 932            Ok(())
 933        } else {
 934            Err(anyhow!("contact already requested"))?
 935        }
 936    }
 937
 938    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 939        let (id_a, id_b) = if responder_id < requester_id {
 940            (responder_id, requester_id)
 941        } else {
 942            (requester_id, responder_id)
 943        };
 944        let query = "
 945            DELETE FROM contacts
 946            WHERE user_id_a = $1 AND user_id_b = $2;
 947        ";
 948        let result = sqlx::query(query)
 949            .bind(id_a.0)
 950            .bind(id_b.0)
 951            .execute(&self.pool)
 952            .await?;
 953
 954        if result.rows_affected() == 1 {
 955            Ok(())
 956        } else {
 957            Err(anyhow!("no such contact"))?
 958        }
 959    }
 960
 961    async fn dismiss_contact_notification(
 962        &self,
 963        user_id: UserId,
 964        contact_user_id: UserId,
 965    ) -> Result<()> {
 966        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
 967            (user_id, contact_user_id, true)
 968        } else {
 969            (contact_user_id, user_id, false)
 970        };
 971
 972        let query = "
 973            UPDATE contacts
 974            SET should_notify = 'f'
 975            WHERE
 976                user_id_a = $1 AND user_id_b = $2 AND
 977                (
 978                    (a_to_b = $3 AND accepted) OR
 979                    (a_to_b != $3 AND NOT accepted)
 980                );
 981        ";
 982
 983        let result = sqlx::query(query)
 984            .bind(id_a.0)
 985            .bind(id_b.0)
 986            .bind(a_to_b)
 987            .execute(&self.pool)
 988            .await?;
 989
 990        if result.rows_affected() == 0 {
 991            Err(anyhow!("no such contact request"))?;
 992        }
 993
 994        Ok(())
 995    }
 996
 997    async fn respond_to_contact_request(
 998        &self,
 999        responder_id: UserId,
1000        requester_id: UserId,
1001        accept: bool,
1002    ) -> Result<()> {
1003        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1004            (responder_id, requester_id, false)
1005        } else {
1006            (requester_id, responder_id, true)
1007        };
1008        let result = if accept {
1009            let query = "
1010                UPDATE contacts
1011                SET accepted = 't', should_notify = 't'
1012                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1013            ";
1014            sqlx::query(query)
1015                .bind(id_a.0)
1016                .bind(id_b.0)
1017                .bind(a_to_b)
1018                .execute(&self.pool)
1019                .await?
1020        } else {
1021            let query = "
1022                DELETE FROM contacts
1023                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1024            ";
1025            sqlx::query(query)
1026                .bind(id_a.0)
1027                .bind(id_b.0)
1028                .bind(a_to_b)
1029                .execute(&self.pool)
1030                .await?
1031        };
1032        if result.rows_affected() == 1 {
1033            Ok(())
1034        } else {
1035            Err(anyhow!("no such contact request"))?
1036        }
1037    }
1038
1039    // access tokens
1040
1041    async fn create_access_token_hash(
1042        &self,
1043        user_id: UserId,
1044        access_token_hash: &str,
1045        max_access_token_count: usize,
1046    ) -> Result<()> {
1047        let insert_query = "
1048            INSERT INTO access_tokens (user_id, hash)
1049            VALUES ($1, $2);
1050        ";
1051        let cleanup_query = "
1052            DELETE FROM access_tokens
1053            WHERE id IN (
1054                SELECT id from access_tokens
1055                WHERE user_id = $1
1056                ORDER BY id DESC
1057                OFFSET $3
1058            )
1059        ";
1060
1061        let mut tx = self.pool.begin().await?;
1062        sqlx::query(insert_query)
1063            .bind(user_id.0)
1064            .bind(access_token_hash)
1065            .execute(&mut tx)
1066            .await?;
1067        sqlx::query(cleanup_query)
1068            .bind(user_id.0)
1069            .bind(access_token_hash)
1070            .bind(max_access_token_count as i32)
1071            .execute(&mut tx)
1072            .await?;
1073        Ok(tx.commit().await?)
1074    }
1075
1076    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1077        let query = "
1078            SELECT hash
1079            FROM access_tokens
1080            WHERE user_id = $1
1081            ORDER BY id DESC
1082        ";
1083        Ok(sqlx::query_scalar(query)
1084            .bind(user_id.0)
1085            .fetch_all(&self.pool)
1086            .await?)
1087    }
1088
1089    // orgs
1090
1091    #[allow(unused)] // Help rust-analyzer
1092    #[cfg(any(test, feature = "seed-support"))]
1093    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1094        let query = "
1095            SELECT *
1096            FROM orgs
1097            WHERE slug = $1
1098        ";
1099        Ok(sqlx::query_as(query)
1100            .bind(slug)
1101            .fetch_optional(&self.pool)
1102            .await?)
1103    }
1104
1105    #[cfg(any(test, feature = "seed-support"))]
1106    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1107        let query = "
1108            INSERT INTO orgs (name, slug)
1109            VALUES ($1, $2)
1110            RETURNING id
1111        ";
1112        Ok(sqlx::query_scalar(query)
1113            .bind(name)
1114            .bind(slug)
1115            .fetch_one(&self.pool)
1116            .await
1117            .map(OrgId)?)
1118    }
1119
1120    #[cfg(any(test, feature = "seed-support"))]
1121    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1122        let query = "
1123            INSERT INTO org_memberships (org_id, user_id, admin)
1124            VALUES ($1, $2, $3)
1125            ON CONFLICT DO NOTHING
1126        ";
1127        Ok(sqlx::query(query)
1128            .bind(org_id.0)
1129            .bind(user_id.0)
1130            .bind(is_admin)
1131            .execute(&self.pool)
1132            .await
1133            .map(drop)?)
1134    }
1135
1136    // channels
1137
1138    #[cfg(any(test, feature = "seed-support"))]
1139    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1140        let query = "
1141            INSERT INTO channels (owner_id, owner_is_user, name)
1142            VALUES ($1, false, $2)
1143            RETURNING id
1144        ";
1145        Ok(sqlx::query_scalar(query)
1146            .bind(org_id.0)
1147            .bind(name)
1148            .fetch_one(&self.pool)
1149            .await
1150            .map(ChannelId)?)
1151    }
1152
1153    #[allow(unused)] // Help rust-analyzer
1154    #[cfg(any(test, feature = "seed-support"))]
1155    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1156        let query = "
1157            SELECT *
1158            FROM channels
1159            WHERE
1160                channels.owner_is_user = false AND
1161                channels.owner_id = $1
1162        ";
1163        Ok(sqlx::query_as(query)
1164            .bind(org_id.0)
1165            .fetch_all(&self.pool)
1166            .await?)
1167    }
1168
1169    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1170        let query = "
1171            SELECT
1172                channels.*
1173            FROM
1174                channel_memberships, channels
1175            WHERE
1176                channel_memberships.user_id = $1 AND
1177                channel_memberships.channel_id = channels.id
1178        ";
1179        Ok(sqlx::query_as(query)
1180            .bind(user_id.0)
1181            .fetch_all(&self.pool)
1182            .await?)
1183    }
1184
1185    async fn can_user_access_channel(
1186        &self,
1187        user_id: UserId,
1188        channel_id: ChannelId,
1189    ) -> Result<bool> {
1190        let query = "
1191            SELECT id
1192            FROM channel_memberships
1193            WHERE user_id = $1 AND channel_id = $2
1194            LIMIT 1
1195        ";
1196        Ok(sqlx::query_scalar::<_, i32>(query)
1197            .bind(user_id.0)
1198            .bind(channel_id.0)
1199            .fetch_optional(&self.pool)
1200            .await
1201            .map(|e| e.is_some())?)
1202    }
1203
1204    #[cfg(any(test, feature = "seed-support"))]
1205    async fn add_channel_member(
1206        &self,
1207        channel_id: ChannelId,
1208        user_id: UserId,
1209        is_admin: bool,
1210    ) -> Result<()> {
1211        let query = "
1212            INSERT INTO channel_memberships (channel_id, user_id, admin)
1213            VALUES ($1, $2, $3)
1214            ON CONFLICT DO NOTHING
1215        ";
1216        Ok(sqlx::query(query)
1217            .bind(channel_id.0)
1218            .bind(user_id.0)
1219            .bind(is_admin)
1220            .execute(&self.pool)
1221            .await
1222            .map(drop)?)
1223    }
1224
1225    // messages
1226
1227    async fn create_channel_message(
1228        &self,
1229        channel_id: ChannelId,
1230        sender_id: UserId,
1231        body: &str,
1232        timestamp: OffsetDateTime,
1233        nonce: u128,
1234    ) -> Result<MessageId> {
1235        let query = "
1236            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1237            VALUES ($1, $2, $3, $4, $5)
1238            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1239            RETURNING id
1240        ";
1241        Ok(sqlx::query_scalar(query)
1242            .bind(channel_id.0)
1243            .bind(sender_id.0)
1244            .bind(body)
1245            .bind(timestamp)
1246            .bind(Uuid::from_u128(nonce))
1247            .fetch_one(&self.pool)
1248            .await
1249            .map(MessageId)?)
1250    }
1251
1252    async fn get_channel_messages(
1253        &self,
1254        channel_id: ChannelId,
1255        count: usize,
1256        before_id: Option<MessageId>,
1257    ) -> Result<Vec<ChannelMessage>> {
1258        let query = r#"
1259            SELECT * FROM (
1260                SELECT
1261                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1262                FROM
1263                    channel_messages
1264                WHERE
1265                    channel_id = $1 AND
1266                    id < $2
1267                ORDER BY id DESC
1268                LIMIT $3
1269            ) as recent_messages
1270            ORDER BY id ASC
1271        "#;
1272        Ok(sqlx::query_as(query)
1273            .bind(channel_id.0)
1274            .bind(before_id.unwrap_or(MessageId::MAX))
1275            .bind(count as i64)
1276            .fetch_all(&self.pool)
1277            .await?)
1278    }
1279
1280    #[cfg(test)]
1281    async fn teardown(&self, url: &str) {
1282        use util::ResultExt;
1283
1284        let query = "
1285            SELECT pg_terminate_backend(pg_stat_activity.pid)
1286            FROM pg_stat_activity
1287            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1288        ";
1289        sqlx::query(query).execute(&self.pool).await.log_err();
1290        self.pool.close().await;
1291        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1292            .await
1293            .log_err();
1294    }
1295
1296    #[cfg(test)]
1297    fn as_fake(&self) -> Option<&tests::FakeDb> {
1298        None
1299    }
1300}
1301
/// Declares a public newtype named `$name` wrapping an `i32` identifier.
///
/// The generated type is transparent for both sqlx and serde, has a `MAX`
/// constant, converts to and from `u64` with `as` casts, and displays as
/// the inner integer.
macro_rules! id_type {
    ($name:ident) => {
        /// Typed wrapper around an `i32` identifier.
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            Serialize,
            Deserialize,
            sqlx::Type,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id (`i32::MAX`).
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Converts from a `u64` with an `as` cast; values that do not
            /// fit in an `i32` are truncated.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $name(value as i32)
            }

            /// Converts to a `u64` with an `as` cast; a negative id would be
            /// sign-extended.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            /// Formats the id by delegating to the inner integer.
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
1344
1345id_type!(UserId);
1346#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
1347pub struct User {
1348    pub id: UserId,
1349    pub github_login: String,
1350    pub email_address: Option<String>,
1351    pub admin: bool,
1352    pub invite_code: Option<String>,
1353    pub invite_count: i32,
1354    pub connected_once: bool,
1355}
1356
1357id_type!(ProjectId);
1358#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
1359pub struct Project {
1360    pub id: ProjectId,
1361    pub host_user_id: UserId,
1362    pub unregistered: bool,
1363}
1364
1365#[derive(Clone, Debug, PartialEq, Serialize)]
1366pub struct UserActivitySummary {
1367    pub id: UserId,
1368    pub github_login: String,
1369    pub project_activity: Vec<ProjectActivitySummary>,
1370}
1371
1372#[derive(Clone, Debug, PartialEq, Serialize)]
1373pub struct ProjectActivitySummary {
1374    id: ProjectId,
1375    duration: Duration,
1376    max_collaborators: usize,
1377}
1378
1379#[derive(Clone, Debug, PartialEq, Serialize)]
1380pub struct UserActivityPeriod {
1381    project_id: ProjectId,
1382    #[serde(with = "time::serde::iso8601")]
1383    start: OffsetDateTime,
1384    #[serde(with = "time::serde::iso8601")]
1385    end: OffsetDateTime,
1386    extensions: HashMap<String, usize>,
1387}
1388
1389id_type!(OrgId);
1390#[derive(FromRow)]
1391pub struct Org {
1392    pub id: OrgId,
1393    pub name: String,
1394    pub slug: String,
1395}
1396
1397id_type!(ChannelId);
1398#[derive(Clone, Debug, FromRow, Serialize)]
1399pub struct Channel {
1400    pub id: ChannelId,
1401    pub name: String,
1402    pub owner_id: i32,
1403    pub owner_is_user: bool,
1404}
1405
1406id_type!(MessageId);
1407#[derive(Clone, Debug, FromRow)]
1408pub struct ChannelMessage {
1409    pub id: MessageId,
1410    pub channel_id: ChannelId,
1411    pub sender_id: UserId,
1412    pub body: String,
1413    pub sent_at: OffsetDateTime,
1414    pub nonce: Uuid,
1415}
1416
1417#[derive(Clone, Debug, PartialEq, Eq)]
1418pub enum Contact {
1419    Accepted {
1420        user_id: UserId,
1421        should_notify: bool,
1422    },
1423    Outgoing {
1424        user_id: UserId,
1425    },
1426    Incoming {
1427        user_id: UserId,
1428        should_notify: bool,
1429    },
1430}
1431
1432impl Contact {
1433    pub fn user_id(&self) -> UserId {
1434        match self {
1435            Contact::Accepted { user_id, .. } => *user_id,
1436            Contact::Outgoing { user_id } => *user_id,
1437            Contact::Incoming { user_id, .. } => *user_id,
1438        }
1439    }
1440}
1441
1442#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
1443pub struct IncomingContactRequest {
1444    pub requester_id: UserId,
1445    pub should_notify: bool,
1446}
1447
/// Builds a SQL `LIKE` pattern that matches the alphanumeric characters of
/// `string` in order, with anything in between.
///
/// Each alphanumeric character is preceded by `%`, every other character is
/// dropped, and a trailing `%` is appended (so `"x y"` becomes `"%x%y%"`).
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    pattern.extend(
        string
            .chars()
            .filter(|c| c.is_alphanumeric())
            .flat_map(|c| ['%', c]),
    );
    pattern.push('%');
    pattern
}
1459
1460fn random_invite_code() -> String {
1461    nanoid::nanoid!(16)
1462}
1463
1464#[cfg(test)]
1465pub mod tests {
1466    use super::*;
1467    use anyhow::anyhow;
1468    use collections::BTreeMap;
1469    use gpui::executor::{Background, Deterministic};
1470    use lazy_static::lazy_static;
1471    use parking_lot::Mutex;
1472    use rand::prelude::*;
1473    use sqlx::{
1474        migrate::{MigrateDatabase, Migrator},
1475        Postgres,
1476    };
1477    use std::{path::Path, sync::Arc};
1478    use util::post_inc;
1479
1480    #[tokio::test(flavor = "multi_thread")]
1481    async fn test_get_users_by_ids() {
1482        for test_db in [
1483            TestDb::postgres().await,
1484            TestDb::fake(build_background_executor()),
1485        ] {
1486            let db = test_db.db();
1487
1488            let user = db.create_user("user", None, false).await.unwrap();
1489            let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1490            let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1491            let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1492
1493            assert_eq!(
1494                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1495                    .await
1496                    .unwrap(),
1497                vec![
1498                    User {
1499                        id: user,
1500                        github_login: "user".to_string(),
1501                        admin: false,
1502                        ..Default::default()
1503                    },
1504                    User {
1505                        id: friend1,
1506                        github_login: "friend-1".to_string(),
1507                        admin: false,
1508                        ..Default::default()
1509                    },
1510                    User {
1511                        id: friend2,
1512                        github_login: "friend-2".to_string(),
1513                        admin: false,
1514                        ..Default::default()
1515                    },
1516                    User {
1517                        id: friend3,
1518                        github_login: "friend-3".to_string(),
1519                        admin: false,
1520                        ..Default::default()
1521                    }
1522                ]
1523            );
1524        }
1525    }
1526
1527    #[tokio::test(flavor = "multi_thread")]
1528    async fn test_create_users() {
1529        let db = TestDb::postgres().await;
1530        let db = db.db();
1531
1532        // Create the first batch of users, ensuring invite counts are assigned
1533        // correctly and the respective invite codes are unique.
1534        let user_ids_batch_1 = db
1535            .create_users(vec![
1536                ("user1".to_string(), "hi@user1.com".to_string(), 5),
1537                ("user2".to_string(), "hi@user2.com".to_string(), 4),
1538                ("user3".to_string(), "hi@user3.com".to_string(), 3),
1539            ])
1540            .await
1541            .unwrap();
1542        assert_eq!(user_ids_batch_1.len(), 3);
1543
1544        let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
1545        assert_eq!(users.len(), 3);
1546        assert_eq!(users[0].github_login, "user1");
1547        assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
1548        assert_eq!(users[0].invite_count, 5);
1549        assert_eq!(users[1].github_login, "user2");
1550        assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
1551        assert_eq!(users[1].invite_count, 4);
1552        assert_eq!(users[2].github_login, "user3");
1553        assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
1554        assert_eq!(users[2].invite_count, 3);
1555
1556        let invite_code_1 = users[0].invite_code.clone().unwrap();
1557        let invite_code_2 = users[1].invite_code.clone().unwrap();
1558        let invite_code_3 = users[2].invite_code.clone().unwrap();
1559        assert_ne!(invite_code_1, invite_code_2);
1560        assert_ne!(invite_code_1, invite_code_3);
1561        assert_ne!(invite_code_2, invite_code_3);
1562
1563        // Create the second batch of users and include a user that is already in the database, ensuring
1564        // the invite count for the existing user is updated without changing their invite code.
1565        let user_ids_batch_2 = db
1566            .create_users(vec![
1567                ("user2".to_string(), "hi@user2.com".to_string(), 10),
1568                ("user4".to_string(), "hi@user4.com".to_string(), 2),
1569            ])
1570            .await
1571            .unwrap();
1572        assert_eq!(user_ids_batch_2.len(), 2);
1573        assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);
1574
1575        let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
1576        assert_eq!(users.len(), 2);
1577        assert_eq!(users[0].github_login, "user2");
1578        assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
1579        assert_eq!(users[0].invite_count, 10);
1580        assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
1581        assert_eq!(users[1].github_login, "user4");
1582        assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
1583        assert_eq!(users[1].invite_count, 2);
1584
1585        let invite_code_4 = users[1].invite_code.clone().unwrap();
1586        assert_ne!(invite_code_4, invite_code_1);
1587        assert_ne!(invite_code_4, invite_code_2);
1588        assert_ne!(invite_code_4, invite_code_3);
1589    }
1590
1591    #[tokio::test(flavor = "multi_thread")]
1592    async fn test_worktree_extensions() {
1593        let test_db = TestDb::postgres().await;
1594        let db = test_db.db();
1595
1596        let user = db.create_user("user_1", None, false).await.unwrap();
1597        let project = db.register_project(user).await.unwrap();
1598
1599        db.update_worktree_extensions(project, 100, Default::default())
1600            .await
1601            .unwrap();
1602        db.update_worktree_extensions(
1603            project,
1604            100,
1605            [("rs".to_string(), 5), ("md".to_string(), 3)]
1606                .into_iter()
1607                .collect(),
1608        )
1609        .await
1610        .unwrap();
1611        db.update_worktree_extensions(
1612            project,
1613            100,
1614            [("rs".to_string(), 6), ("md".to_string(), 5)]
1615                .into_iter()
1616                .collect(),
1617        )
1618        .await
1619        .unwrap();
1620        db.update_worktree_extensions(
1621            project,
1622            101,
1623            [("ts".to_string(), 2), ("md".to_string(), 1)]
1624                .into_iter()
1625                .collect(),
1626        )
1627        .await
1628        .unwrap();
1629
1630        assert_eq!(
1631            db.get_project_extensions(project).await.unwrap(),
1632            [
1633                (
1634                    100,
1635                    [("rs".into(), 6), ("md".into(), 5),]
1636                        .into_iter()
1637                        .collect::<HashMap<_, _>>()
1638                ),
1639                (
1640                    101,
1641                    [("ts".into(), 2), ("md".into(), 1),]
1642                        .into_iter()
1643                        .collect::<HashMap<_, _>>()
1644                )
1645            ]
1646            .into_iter()
1647            .collect()
1648        );
1649    }
1650
1651    #[tokio::test(flavor = "multi_thread")]
1652    async fn test_user_activity() {
1653        let test_db = TestDb::postgres().await;
1654        let db = test_db.db();
1655
1656        let user_1 = db.create_user("user_1", None, false).await.unwrap();
1657        let user_2 = db.create_user("user_2", None, false).await.unwrap();
1658        let user_3 = db.create_user("user_3", None, false).await.unwrap();
1659        let project_1 = db.register_project(user_1).await.unwrap();
1660        db.update_worktree_extensions(
1661            project_1,
1662            1,
1663            HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]),
1664        )
1665        .await
1666        .unwrap();
1667        let project_2 = db.register_project(user_2).await.unwrap();
1668        let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);
1669
1670        // User 2 opens a project
1671        let t1 = t0 + Duration::from_secs(10);
1672        db.record_user_activity(t0..t1, &[(user_2, project_2)])
1673            .await
1674            .unwrap();
1675
1676        let t2 = t1 + Duration::from_secs(10);
1677        db.record_user_activity(t1..t2, &[(user_2, project_2)])
1678            .await
1679            .unwrap();
1680
1681        // User 1 joins the project
1682        let t3 = t2 + Duration::from_secs(10);
1683        db.record_user_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
1684            .await
1685            .unwrap();
1686
1687        // User 1 opens another project
1688        let t4 = t3 + Duration::from_secs(10);
1689        db.record_user_activity(
1690            t3..t4,
1691            &[
1692                (user_2, project_2),
1693                (user_1, project_2),
1694                (user_1, project_1),
1695            ],
1696        )
1697        .await
1698        .unwrap();
1699
1700        // User 3 joins that project
1701        let t5 = t4 + Duration::from_secs(10);
1702        db.record_user_activity(
1703            t4..t5,
1704            &[
1705                (user_2, project_2),
1706                (user_1, project_2),
1707                (user_1, project_1),
1708                (user_3, project_1),
1709            ],
1710        )
1711        .await
1712        .unwrap();
1713
1714        // User 2 leaves
1715        let t6 = t5 + Duration::from_secs(5);
1716        db.record_user_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
1717            .await
1718            .unwrap();
1719
1720        let t7 = t6 + Duration::from_secs(60);
1721        let t8 = t7 + Duration::from_secs(10);
1722        db.record_user_activity(t7..t8, &[(user_1, project_1)])
1723            .await
1724            .unwrap();
1725
1726        assert_eq!(
1727            db.get_top_users_activity_summary(t0..t6, 10).await.unwrap(),
1728            &[
1729                UserActivitySummary {
1730                    id: user_1,
1731                    github_login: "user_1".to_string(),
1732                    project_activity: vec![
1733                        ProjectActivitySummary {
1734                            id: project_1,
1735                            duration: Duration::from_secs(25),
1736                            max_collaborators: 2
1737                        },
1738                        ProjectActivitySummary {
1739                            id: project_2,
1740                            duration: Duration::from_secs(30),
1741                            max_collaborators: 2
1742                        }
1743                    ]
1744                },
1745                UserActivitySummary {
1746                    id: user_2,
1747                    github_login: "user_2".to_string(),
1748                    project_activity: vec![ProjectActivitySummary {
1749                        id: project_2,
1750                        duration: Duration::from_secs(50),
1751                        max_collaborators: 2
1752                    }]
1753                },
1754                UserActivitySummary {
1755                    id: user_3,
1756                    github_login: "user_3".to_string(),
1757                    project_activity: vec![ProjectActivitySummary {
1758                        id: project_1,
1759                        duration: Duration::from_secs(15),
1760                        max_collaborators: 2
1761                    }]
1762                },
1763            ]
1764        );
1765
1766        assert_eq!(
1767            db.get_active_user_count(t0..t6, Duration::from_secs(56), false)
1768                .await
1769                .unwrap(),
1770            0
1771        );
1772        assert_eq!(
1773            db.get_active_user_count(t0..t6, Duration::from_secs(56), true)
1774                .await
1775                .unwrap(),
1776            0
1777        );
1778        assert_eq!(
1779            db.get_active_user_count(t0..t6, Duration::from_secs(54), false)
1780                .await
1781                .unwrap(),
1782            1
1783        );
1784        assert_eq!(
1785            db.get_active_user_count(t0..t6, Duration::from_secs(54), true)
1786                .await
1787                .unwrap(),
1788            1
1789        );
1790        assert_eq!(
1791            db.get_active_user_count(t0..t6, Duration::from_secs(30), false)
1792                .await
1793                .unwrap(),
1794            2
1795        );
1796        assert_eq!(
1797            db.get_active_user_count(t0..t6, Duration::from_secs(30), true)
1798                .await
1799                .unwrap(),
1800            2
1801        );
1802        assert_eq!(
1803            db.get_active_user_count(t0..t6, Duration::from_secs(10), false)
1804                .await
1805                .unwrap(),
1806            3
1807        );
1808        assert_eq!(
1809            db.get_active_user_count(t0..t6, Duration::from_secs(10), true)
1810                .await
1811                .unwrap(),
1812            3
1813        );
1814        assert_eq!(
1815            db.get_active_user_count(t0..t1, Duration::from_secs(5), false)
1816                .await
1817                .unwrap(),
1818            1
1819        );
1820        assert_eq!(
1821            db.get_active_user_count(t0..t1, Duration::from_secs(5), true)
1822                .await
1823                .unwrap(),
1824            0
1825        );
1826
1827        assert_eq!(
1828            db.get_user_activity_timeline(t3..t6, user_1).await.unwrap(),
1829            &[
1830                UserActivityPeriod {
1831                    project_id: project_1,
1832                    start: t3,
1833                    end: t6,
1834                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1835                },
1836                UserActivityPeriod {
1837                    project_id: project_2,
1838                    start: t3,
1839                    end: t5,
1840                    extensions: Default::default(),
1841                },
1842            ]
1843        );
1844        assert_eq!(
1845            db.get_user_activity_timeline(t0..t8, user_1).await.unwrap(),
1846            &[
1847                UserActivityPeriod {
1848                    project_id: project_2,
1849                    start: t2,
1850                    end: t5,
1851                    extensions: Default::default(),
1852                },
1853                UserActivityPeriod {
1854                    project_id: project_1,
1855                    start: t3,
1856                    end: t6,
1857                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1858                },
1859                UserActivityPeriod {
1860                    project_id: project_1,
1861                    start: t7,
1862                    end: t8,
1863                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1864                },
1865            ]
1866        );
1867    }
1868
1869    #[tokio::test(flavor = "multi_thread")]
1870    async fn test_recent_channel_messages() {
1871        for test_db in [
1872            TestDb::postgres().await,
1873            TestDb::fake(build_background_executor()),
1874        ] {
1875            let db = test_db.db();
1876            let user = db.create_user("user", None, false).await.unwrap();
1877            let org = db.create_org("org", "org").await.unwrap();
1878            let channel = db.create_org_channel(org, "channel").await.unwrap();
1879            for i in 0..10 {
1880                db.create_channel_message(
1881                    channel,
1882                    user,
1883                    &i.to_string(),
1884                    OffsetDateTime::now_utc(),
1885                    i,
1886                )
1887                .await
1888                .unwrap();
1889            }
1890
1891            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1892            assert_eq!(
1893                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1894                ["5", "6", "7", "8", "9"]
1895            );
1896
1897            let prev_messages = db
1898                .get_channel_messages(channel, 4, Some(messages[0].id))
1899                .await
1900                .unwrap();
1901            assert_eq!(
1902                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1903                ["1", "2", "3", "4"]
1904            );
1905        }
1906    }
1907
1908    #[tokio::test(flavor = "multi_thread")]
1909    async fn test_channel_message_nonces() {
1910        for test_db in [
1911            TestDb::postgres().await,
1912            TestDb::fake(build_background_executor()),
1913        ] {
1914            let db = test_db.db();
1915            let user = db.create_user("user", None, false).await.unwrap();
1916            let org = db.create_org("org", "org").await.unwrap();
1917            let channel = db.create_org_channel(org, "channel").await.unwrap();
1918
1919            let msg1_id = db
1920                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1921                .await
1922                .unwrap();
1923            let msg2_id = db
1924                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1925                .await
1926                .unwrap();
1927            let msg3_id = db
1928                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1929                .await
1930                .unwrap();
1931            let msg4_id = db
1932                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1933                .await
1934                .unwrap();
1935
1936            assert_ne!(msg1_id, msg2_id);
1937            assert_eq!(msg1_id, msg3_id);
1938            assert_eq!(msg2_id, msg4_id);
1939        }
1940    }
1941
1942    #[tokio::test(flavor = "multi_thread")]
1943    async fn test_create_access_tokens() {
1944        let test_db = TestDb::postgres().await;
1945        let db = test_db.db();
1946        let user = db.create_user("the-user", None, false).await.unwrap();
1947
1948        db.create_access_token_hash(user, "h1", 3).await.unwrap();
1949        db.create_access_token_hash(user, "h2", 3).await.unwrap();
1950        assert_eq!(
1951            db.get_access_token_hashes(user).await.unwrap(),
1952            &["h2".to_string(), "h1".to_string()]
1953        );
1954
1955        db.create_access_token_hash(user, "h3", 3).await.unwrap();
1956        assert_eq!(
1957            db.get_access_token_hashes(user).await.unwrap(),
1958            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1959        );
1960
1961        db.create_access_token_hash(user, "h4", 3).await.unwrap();
1962        assert_eq!(
1963            db.get_access_token_hashes(user).await.unwrap(),
1964            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1965        );
1966
1967        db.create_access_token_hash(user, "h5", 3).await.unwrap();
1968        assert_eq!(
1969            db.get_access_token_hashes(user).await.unwrap(),
1970            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1971        );
1972    }
1973
1974    #[test]
1975    fn test_fuzzy_like_string() {
1976        assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1977        assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1978        assert_eq!(fuzzy_like_string(" z  "), "%z%");
1979    }
1980
1981    #[tokio::test(flavor = "multi_thread")]
1982    async fn test_fuzzy_search_users() {
1983        let test_db = TestDb::postgres().await;
1984        let db = test_db.db();
1985        for github_login in [
1986            "California",
1987            "colorado",
1988            "oregon",
1989            "washington",
1990            "florida",
1991            "delaware",
1992            "rhode-island",
1993        ] {
1994            db.create_user(github_login, None, false).await.unwrap();
1995        }
1996
1997        assert_eq!(
1998            fuzzy_search_user_names(db, "clr").await,
1999            &["colorado", "California"]
2000        );
2001        assert_eq!(
2002            fuzzy_search_user_names(db, "ro").await,
2003            &["rhode-island", "colorado", "oregon"],
2004        );
2005
2006        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
2007            db.fuzzy_search_users(query, 10)
2008                .await
2009                .unwrap()
2010                .into_iter()
2011                .map(|user| user.github_login)
2012                .collect::<Vec<_>>()
2013        }
2014    }
2015
2016    #[tokio::test(flavor = "multi_thread")]
2017    async fn test_add_contacts() {
2018        for test_db in [
2019            TestDb::postgres().await,
2020            TestDb::fake(build_background_executor()),
2021        ] {
2022            let db = test_db.db();
2023
2024            let user_1 = db.create_user("user1", None, false).await.unwrap();
2025            let user_2 = db.create_user("user2", None, false).await.unwrap();
2026            let user_3 = db.create_user("user3", None, false).await.unwrap();
2027
2028            // User starts with no contacts
2029            assert_eq!(
2030                db.get_contacts(user_1).await.unwrap(),
2031                vec![Contact::Accepted {
2032                    user_id: user_1,
2033                    should_notify: false
2034                }],
2035            );
2036
2037            // User requests a contact. Both users see the pending request.
2038            db.send_contact_request(user_1, user_2).await.unwrap();
2039            assert!(!db.has_contact(user_1, user_2).await.unwrap());
2040            assert!(!db.has_contact(user_2, user_1).await.unwrap());
2041            assert_eq!(
2042                db.get_contacts(user_1).await.unwrap(),
2043                &[
2044                    Contact::Accepted {
2045                        user_id: user_1,
2046                        should_notify: false
2047                    },
2048                    Contact::Outgoing { user_id: user_2 }
2049                ],
2050            );
2051            assert_eq!(
2052                db.get_contacts(user_2).await.unwrap(),
2053                &[
2054                    Contact::Incoming {
2055                        user_id: user_1,
2056                        should_notify: true
2057                    },
2058                    Contact::Accepted {
2059                        user_id: user_2,
2060                        should_notify: false
2061                    },
2062                ]
2063            );
2064
2065            // User 2 dismisses the contact request notification without accepting or rejecting.
2066            // We shouldn't notify them again.
2067            db.dismiss_contact_notification(user_1, user_2)
2068                .await
2069                .unwrap_err();
2070            db.dismiss_contact_notification(user_2, user_1)
2071                .await
2072                .unwrap();
2073            assert_eq!(
2074                db.get_contacts(user_2).await.unwrap(),
2075                &[
2076                    Contact::Incoming {
2077                        user_id: user_1,
2078                        should_notify: false
2079                    },
2080                    Contact::Accepted {
2081                        user_id: user_2,
2082                        should_notify: false
2083                    },
2084                ]
2085            );
2086
2087            // User can't accept their own contact request
2088            db.respond_to_contact_request(user_1, user_2, true)
2089                .await
2090                .unwrap_err();
2091
2092            // User accepts a contact request. Both users see the contact.
2093            db.respond_to_contact_request(user_2, user_1, true)
2094                .await
2095                .unwrap();
2096            assert_eq!(
2097                db.get_contacts(user_1).await.unwrap(),
2098                &[
2099                    Contact::Accepted {
2100                        user_id: user_1,
2101                        should_notify: false
2102                    },
2103                    Contact::Accepted {
2104                        user_id: user_2,
2105                        should_notify: true
2106                    }
2107                ],
2108            );
2109            assert!(db.has_contact(user_1, user_2).await.unwrap());
2110            assert!(db.has_contact(user_2, user_1).await.unwrap());
2111            assert_eq!(
2112                db.get_contacts(user_2).await.unwrap(),
2113                &[
2114                    Contact::Accepted {
2115                        user_id: user_1,
2116                        should_notify: false,
2117                    },
2118                    Contact::Accepted {
2119                        user_id: user_2,
2120                        should_notify: false,
2121                    },
2122                ]
2123            );
2124
2125            // Users cannot re-request existing contacts.
2126            db.send_contact_request(user_1, user_2).await.unwrap_err();
2127            db.send_contact_request(user_2, user_1).await.unwrap_err();
2128
2129            // Users can't dismiss notifications of them accepting other users' requests.
2130            db.dismiss_contact_notification(user_2, user_1)
2131                .await
2132                .unwrap_err();
2133            assert_eq!(
2134                db.get_contacts(user_1).await.unwrap(),
2135                &[
2136                    Contact::Accepted {
2137                        user_id: user_1,
2138                        should_notify: false
2139                    },
2140                    Contact::Accepted {
2141                        user_id: user_2,
2142                        should_notify: true,
2143                    },
2144                ]
2145            );
2146
2147            // Users can dismiss notifications of other users accepting their requests.
2148            db.dismiss_contact_notification(user_1, user_2)
2149                .await
2150                .unwrap();
2151            assert_eq!(
2152                db.get_contacts(user_1).await.unwrap(),
2153                &[
2154                    Contact::Accepted {
2155                        user_id: user_1,
2156                        should_notify: false
2157                    },
2158                    Contact::Accepted {
2159                        user_id: user_2,
2160                        should_notify: false,
2161                    },
2162                ]
2163            );
2164
2165            // Users send each other concurrent contact requests and
2166            // see that they are immediately accepted.
2167            db.send_contact_request(user_1, user_3).await.unwrap();
2168            db.send_contact_request(user_3, user_1).await.unwrap();
2169            assert_eq!(
2170                db.get_contacts(user_1).await.unwrap(),
2171                &[
2172                    Contact::Accepted {
2173                        user_id: user_1,
2174                        should_notify: false
2175                    },
2176                    Contact::Accepted {
2177                        user_id: user_2,
2178                        should_notify: false,
2179                    },
2180                    Contact::Accepted {
2181                        user_id: user_3,
2182                        should_notify: false
2183                    },
2184                ]
2185            );
2186            assert_eq!(
2187                db.get_contacts(user_3).await.unwrap(),
2188                &[
2189                    Contact::Accepted {
2190                        user_id: user_1,
2191                        should_notify: false
2192                    },
2193                    Contact::Accepted {
2194                        user_id: user_3,
2195                        should_notify: false
2196                    }
2197                ],
2198            );
2199
2200            // User declines a contact request. Both users see that it is gone.
2201            db.send_contact_request(user_2, user_3).await.unwrap();
2202            db.respond_to_contact_request(user_3, user_2, false)
2203                .await
2204                .unwrap();
2205            assert!(!db.has_contact(user_2, user_3).await.unwrap());
2206            assert!(!db.has_contact(user_3, user_2).await.unwrap());
2207            assert_eq!(
2208                db.get_contacts(user_2).await.unwrap(),
2209                &[
2210                    Contact::Accepted {
2211                        user_id: user_1,
2212                        should_notify: false
2213                    },
2214                    Contact::Accepted {
2215                        user_id: user_2,
2216                        should_notify: false
2217                    }
2218                ]
2219            );
2220            assert_eq!(
2221                db.get_contacts(user_3).await.unwrap(),
2222                &[
2223                    Contact::Accepted {
2224                        user_id: user_1,
2225                        should_notify: false
2226                    },
2227                    Contact::Accepted {
2228                        user_id: user_3,
2229                        should_notify: false
2230                    }
2231                ],
2232            );
2233        }
2234    }
2235
2236    #[tokio::test(flavor = "multi_thread")]
2237    async fn test_invite_codes() {
2238        let postgres = TestDb::postgres().await;
2239        let db = postgres.db();
2240        let user1 = db.create_user("user-1", None, false).await.unwrap();
2241
2242        // Initially, user 1 has no invite code
2243        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);
2244
2245        // Setting invite count to 0 when no code is assigned does not assign a new code
2246        db.set_invite_count(user1, 0).await.unwrap();
2247        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());
2248
2249        // User 1 creates an invite code that can be used twice.
2250        db.set_invite_count(user1, 2).await.unwrap();
2251        let (invite_code, invite_count) =
2252            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2253        assert_eq!(invite_count, 2);
2254
2255        // User 2 redeems the invite code and becomes a contact of user 1.
2256        let user2 = db
2257            .redeem_invite_code(&invite_code, "user-2", None)
2258            .await
2259            .unwrap();
2260        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2261        assert_eq!(invite_count, 1);
2262        assert_eq!(
2263            db.get_contacts(user1).await.unwrap(),
2264            [
2265                Contact::Accepted {
2266                    user_id: user1,
2267                    should_notify: false
2268                },
2269                Contact::Accepted {
2270                    user_id: user2,
2271                    should_notify: true
2272                }
2273            ]
2274        );
2275        assert_eq!(
2276            db.get_contacts(user2).await.unwrap(),
2277            [
2278                Contact::Accepted {
2279                    user_id: user1,
2280                    should_notify: false
2281                },
2282                Contact::Accepted {
2283                    user_id: user2,
2284                    should_notify: false
2285                }
2286            ]
2287        );
2288
2289        // User 3 redeems the invite code and becomes a contact of user 1.
2290        let user3 = db
2291            .redeem_invite_code(&invite_code, "user-3", None)
2292            .await
2293            .unwrap();
2294        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2295        assert_eq!(invite_count, 0);
2296        assert_eq!(
2297            db.get_contacts(user1).await.unwrap(),
2298            [
2299                Contact::Accepted {
2300                    user_id: user1,
2301                    should_notify: false
2302                },
2303                Contact::Accepted {
2304                    user_id: user2,
2305                    should_notify: true
2306                },
2307                Contact::Accepted {
2308                    user_id: user3,
2309                    should_notify: true
2310                }
2311            ]
2312        );
2313        assert_eq!(
2314            db.get_contacts(user3).await.unwrap(),
2315            [
2316                Contact::Accepted {
2317                    user_id: user1,
2318                    should_notify: false
2319                },
2320                Contact::Accepted {
2321                    user_id: user3,
2322                    should_notify: false
2323                },
2324            ]
2325        );
2326
2327        // Trying to reedem the code for the third time results in an error.
2328        db.redeem_invite_code(&invite_code, "user-4", None)
2329            .await
2330            .unwrap_err();
2331
2332        // Invite count can be updated after the code has been created.
2333        db.set_invite_count(user1, 2).await.unwrap();
2334        let (latest_code, invite_count) =
2335            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2336        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
2337        assert_eq!(invite_count, 2);
2338
2339        // User 4 can now redeem the invite code and becomes a contact of user 1.
2340        let user4 = db
2341            .redeem_invite_code(&invite_code, "user-4", None)
2342            .await
2343            .unwrap();
2344        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2345        assert_eq!(invite_count, 1);
2346        assert_eq!(
2347            db.get_contacts(user1).await.unwrap(),
2348            [
2349                Contact::Accepted {
2350                    user_id: user1,
2351                    should_notify: false
2352                },
2353                Contact::Accepted {
2354                    user_id: user2,
2355                    should_notify: true
2356                },
2357                Contact::Accepted {
2358                    user_id: user3,
2359                    should_notify: true
2360                },
2361                Contact::Accepted {
2362                    user_id: user4,
2363                    should_notify: true
2364                }
2365            ]
2366        );
2367        assert_eq!(
2368            db.get_contacts(user4).await.unwrap(),
2369            [
2370                Contact::Accepted {
2371                    user_id: user1,
2372                    should_notify: false
2373                },
2374                Contact::Accepted {
2375                    user_id: user4,
2376                    should_notify: false
2377                },
2378            ]
2379        );
2380
2381        // An existing user cannot redeem invite codes.
2382        db.redeem_invite_code(&invite_code, "user-2", None)
2383            .await
2384            .unwrap_err();
2385        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2386        assert_eq!(invite_count, 1);
2387
2388        // Ensure invited users get invite codes too.
2389        assert_eq!(
2390            db.get_invite_code_for_user(user2).await.unwrap().unwrap().1,
2391            5
2392        );
2393        assert_eq!(
2394            db.get_invite_code_for_user(user3).await.unwrap().unwrap().1,
2395            5
2396        );
2397        assert_eq!(
2398            db.get_invite_code_for_user(user4).await.unwrap().unwrap().1,
2399            5
2400        );
2401    }
2402
    /// A database handle for tests, backed either by a freshly created
    /// Postgres database or by the in-memory `FakeDb`.
    pub struct TestDb {
        /// The wrapped database. Stored in an `Option` so that `Drop` can
        /// take it out of the struct before tearing it down.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the Postgres test database; left as the empty
        /// string when the fake backend is used.
        pub url: String,
    }
2407
2408    impl TestDb {
2409        #[allow(clippy::await_holding_lock)]
2410        pub async fn postgres() -> Self {
2411            lazy_static! {
2412                static ref LOCK: Mutex<()> = Mutex::new(());
2413            }
2414
2415            let _guard = LOCK.lock();
2416            let mut rng = StdRng::from_entropy();
2417            let name = format!("zed-test-{}", rng.gen::<u128>());
2418            let url = format!("postgres://postgres@localhost/{}", name);
2419            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2420            Postgres::create_database(&url)
2421                .await
2422                .expect("failed to create test db");
2423            let db = PostgresDb::new(&url, 5).await.unwrap();
2424            let migrator = Migrator::new(migrations_path).await.unwrap();
2425            migrator.run(&db.pool).await.unwrap();
2426            Self {
2427                db: Some(Arc::new(db)),
2428                url,
2429            }
2430        }
2431
2432        pub fn fake(background: Arc<Background>) -> Self {
2433            Self {
2434                db: Some(Arc::new(FakeDb::new(background))),
2435                url: Default::default(),
2436            }
2437        }
2438
2439        pub fn db(&self) -> &Arc<dyn Db> {
2440            self.db.as_ref().unwrap()
2441        }
2442    }
2443
2444    impl Drop for TestDb {
2445        fn drop(&mut self) {
2446            if let Some(db) = self.db.take() {
2447                futures::executor::block_on(db.teardown(&self.url));
2448            }
2449        }
2450    }
2451
    /// In-memory implementation of the `Db` trait, used by tests in place of
    /// the Postgres-backed database.
    ///
    /// Each collection sits behind its own mutex. The `next_*_id` fields are
    /// counters; `next_user_id` and `next_project_id` are visibly used to
    /// hand out fresh ids, and the others presumably serve the same purpose
    /// for their tables (confirm against the rest of the impl).
    pub struct FakeDb {
        // Used to inject simulated random delays into the async methods.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        // Keyed by (project id, worktree id, file extension); the value is
        // the count supplied for that extension.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        // NOTE(review): the meaning of the `bool` is not visible in this part
        // of the file; presumably an admin flag — confirm.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        // NOTE(review): as above, the `bool` is presumably an admin flag — confirm.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        pub contacts: Mutex<Vec<FakeContact>>,
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
2469
    /// One entry in `FakeDb::contacts`, describing a contact relationship
    /// between two users.
    // NOTE(review): the exact semantics of the flags are defined by the
    // contact methods further down the file, outside this excerpt —
    // presumably they mirror the columns of the Postgres contacts table.
    #[derive(Debug)]
    pub struct FakeContact {
        pub requester_id: UserId,
        pub responder_id: UserId,
        pub accepted: bool,
        pub should_notify: bool,
    }
2477
2478    impl FakeDb {
2479        pub fn new(background: Arc<Background>) -> Self {
2480            Self {
2481                background,
2482                users: Default::default(),
2483                next_user_id: Mutex::new(0),
2484                projects: Default::default(),
2485                worktree_extensions: Default::default(),
2486                next_project_id: Mutex::new(1),
2487                orgs: Default::default(),
2488                next_org_id: Mutex::new(1),
2489                org_memberships: Default::default(),
2490                channels: Default::default(),
2491                next_channel_id: Mutex::new(1),
2492                channel_memberships: Default::default(),
2493                channel_messages: Default::default(),
2494                next_channel_message_id: Mutex::new(1),
2495                contacts: Default::default(),
2496            }
2497        }
2498    }
2499
2500    #[async_trait]
2501    impl Db for FakeDb {
2502        async fn create_user(
2503            &self,
2504            github_login: &str,
2505            email_address: Option<&str>,
2506            admin: bool,
2507        ) -> Result<UserId> {
2508            self.background.simulate_random_delay().await;
2509
2510            let mut users = self.users.lock();
2511            if let Some(user) = users
2512                .values()
2513                .find(|user| user.github_login == github_login)
2514            {
2515                Ok(user.id)
2516            } else {
2517                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2518                users.insert(
2519                    user_id,
2520                    User {
2521                        id: user_id,
2522                        github_login: github_login.to_string(),
2523                        email_address: email_address.map(str::to_string),
2524                        admin,
2525                        invite_code: None,
2526                        invite_count: 0,
2527                        connected_once: false,
2528                    },
2529                );
2530                Ok(user_id)
2531            }
2532        }
2533
        /// Not supported by the fake database: calling this panics via
        /// `unimplemented!()`.
        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2537
        /// Not supported by the fake database: calling this panics via
        /// `unimplemented!()`.
        async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
            unimplemented!()
        }
2541
2542        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
2543            unimplemented!()
2544        }
2545
2546        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2547            self.background.simulate_random_delay().await;
2548            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2549        }
2550
2551        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2552            self.background.simulate_random_delay().await;
2553            let users = self.users.lock();
2554            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2555        }
2556
2557        async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
2558            unimplemented!()
2559        }
2560
2561        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2562            self.background.simulate_random_delay().await;
2563            Ok(self
2564                .users
2565                .lock()
2566                .values()
2567                .find(|user| user.github_login == github_login)
2568                .cloned())
2569        }
2570
        /// Not supported by the fake database: calling this panics via
        /// `unimplemented!()`.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
2574
2575        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2576            self.background.simulate_random_delay().await;
2577            let mut users = self.users.lock();
2578            let mut user = users
2579                .get_mut(&id)
2580                .ok_or_else(|| anyhow!("user not found"))?;
2581            user.connected_once = connected_once;
2582            Ok(())
2583        }
2584
        /// Not supported by the fake database: calling this panics via
        /// `unimplemented!()`.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }
2588
2589        // invite codes
2590
        /// Not supported by the fake database: calling this panics via
        /// `unimplemented!()`.
        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }
2594
        /// Always reports that the user has no invite code: after a simulated
        /// delay this returns `Ok(None)` regardless of the id passed in.
        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            self.background.simulate_random_delay().await;
            Ok(None)
        }
2599
        /// Not supported by the fake database: calling this panics via
        /// `unimplemented!()`.
        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }
2603
        /// Not supported by the fake database: calling this panics via
        /// `unimplemented!()`.
        async fn redeem_invite_code(
            &self,
            _code: &str,
            _login: &str,
            _email_address: Option<&str>,
        ) -> Result<UserId> {
            unimplemented!()
        }
2612
2613        // projects
2614
2615        async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2616            self.background.simulate_random_delay().await;
2617            if !self.users.lock().contains_key(&host_user_id) {
2618                Err(anyhow!("no such user"))?;
2619            }
2620
2621            let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2622            self.projects.lock().insert(
2623                project_id,
2624                Project {
2625                    id: project_id,
2626                    host_user_id,
2627                    unregistered: false,
2628                },
2629            );
2630            Ok(project_id)
2631        }
2632
2633        async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2634            self.background.simulate_random_delay().await;
2635            self.projects
2636                .lock()
2637                .get_mut(&project_id)
2638                .ok_or_else(|| anyhow!("no such project"))?
2639                .unregistered = true;
2640            Ok(())
2641        }
2642
2643        async fn update_worktree_extensions(
2644            &self,
2645            project_id: ProjectId,
2646            worktree_id: u64,
2647            extensions: HashMap<String, u32>,
2648        ) -> Result<()> {
2649            self.background.simulate_random_delay().await;
2650            if !self.projects.lock().contains_key(&project_id) {
2651                Err(anyhow!("no such project"))?;
2652            }
2653
2654            for (extension, count) in extensions {
2655                self.worktree_extensions
2656                    .lock()
2657                    .insert((project_id, worktree_id, extension), count);
2658            }
2659
2660            Ok(())
2661        }
2662
        /// Not implemented by this fake; calling it panics via `unimplemented!()`.
        async fn get_project_extensions(
            &self,
            _project_id: ProjectId,
        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
            unimplemented!()
        }
2669
        /// Not implemented by this fake; calling it panics via `unimplemented!()`.
        async fn record_user_activity(
            &self,
            _time_period: Range<OffsetDateTime>,
            _active_projects: &[(UserId, ProjectId)],
        ) -> Result<()> {
            unimplemented!()
        }
2677
        /// Not implemented by this fake; calling it panics via `unimplemented!()`.
        async fn get_active_user_count(
            &self,
            _time_period: Range<OffsetDateTime>,
            _min_duration: Duration,
            _only_collaborative: bool,
        ) -> Result<usize> {
            unimplemented!()
        }
2686
        /// Not implemented by this fake; calling it panics via `unimplemented!()`.
        async fn get_top_users_activity_summary(
            &self,
            _time_period: Range<OffsetDateTime>,
            _limit: usize,
        ) -> Result<Vec<UserActivitySummary>> {
            unimplemented!()
        }
2694
        /// Not implemented by this fake; calling it panics via `unimplemented!()`.
        async fn get_user_activity_timeline(
            &self,
            _time_period: Range<OffsetDateTime>,
            _user_id: UserId,
        ) -> Result<Vec<UserActivityPeriod>> {
            unimplemented!()
        }
2702
2703        // contacts
2704
2705        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2706            self.background.simulate_random_delay().await;
2707            let mut contacts = vec![Contact::Accepted {
2708                user_id: id,
2709                should_notify: false,
2710            }];
2711
2712            for contact in self.contacts.lock().iter() {
2713                if contact.requester_id == id {
2714                    if contact.accepted {
2715                        contacts.push(Contact::Accepted {
2716                            user_id: contact.responder_id,
2717                            should_notify: contact.should_notify,
2718                        });
2719                    } else {
2720                        contacts.push(Contact::Outgoing {
2721                            user_id: contact.responder_id,
2722                        });
2723                    }
2724                } else if contact.responder_id == id {
2725                    if contact.accepted {
2726                        contacts.push(Contact::Accepted {
2727                            user_id: contact.requester_id,
2728                            should_notify: false,
2729                        });
2730                    } else {
2731                        contacts.push(Contact::Incoming {
2732                            user_id: contact.requester_id,
2733                            should_notify: contact.should_notify,
2734                        });
2735                    }
2736                }
2737            }
2738
2739            contacts.sort_unstable_by_key(|contact| contact.user_id());
2740            Ok(contacts)
2741        }
2742
2743        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2744            self.background.simulate_random_delay().await;
2745            Ok(self.contacts.lock().iter().any(|contact| {
2746                contact.accepted
2747                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2748                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2749            }))
2750        }
2751
2752        async fn send_contact_request(
2753            &self,
2754            requester_id: UserId,
2755            responder_id: UserId,
2756        ) -> Result<()> {
2757            self.background.simulate_random_delay().await;
2758            let mut contacts = self.contacts.lock();
2759            for contact in contacts.iter_mut() {
2760                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2761                    if contact.accepted {
2762                        Err(anyhow!("contact already exists"))?;
2763                    } else {
2764                        Err(anyhow!("contact already requested"))?;
2765                    }
2766                }
2767                if contact.responder_id == requester_id && contact.requester_id == responder_id {
2768                    if contact.accepted {
2769                        Err(anyhow!("contact already exists"))?;
2770                    } else {
2771                        contact.accepted = true;
2772                        contact.should_notify = false;
2773                        return Ok(());
2774                    }
2775                }
2776            }
2777            contacts.push(FakeContact {
2778                requester_id,
2779                responder_id,
2780                accepted: false,
2781                should_notify: true,
2782            });
2783            Ok(())
2784        }
2785
2786        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2787            self.background.simulate_random_delay().await;
2788            self.contacts.lock().retain(|contact| {
2789                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2790            });
2791            Ok(())
2792        }
2793
2794        async fn dismiss_contact_notification(
2795            &self,
2796            user_id: UserId,
2797            contact_user_id: UserId,
2798        ) -> Result<()> {
2799            self.background.simulate_random_delay().await;
2800            let mut contacts = self.contacts.lock();
2801            for contact in contacts.iter_mut() {
2802                if contact.requester_id == contact_user_id
2803                    && contact.responder_id == user_id
2804                    && !contact.accepted
2805                {
2806                    contact.should_notify = false;
2807                    return Ok(());
2808                }
2809                if contact.requester_id == user_id
2810                    && contact.responder_id == contact_user_id
2811                    && contact.accepted
2812                {
2813                    contact.should_notify = false;
2814                    return Ok(());
2815                }
2816            }
2817            Err(anyhow!("no such notification"))?
2818        }
2819
2820        async fn respond_to_contact_request(
2821            &self,
2822            responder_id: UserId,
2823            requester_id: UserId,
2824            accept: bool,
2825        ) -> Result<()> {
2826            self.background.simulate_random_delay().await;
2827            let mut contacts = self.contacts.lock();
2828            for (ix, contact) in contacts.iter_mut().enumerate() {
2829                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2830                    if contact.accepted {
2831                        Err(anyhow!("contact already confirmed"))?;
2832                    }
2833                    if accept {
2834                        contact.accepted = true;
2835                        contact.should_notify = true;
2836                    } else {
2837                        contacts.remove(ix);
2838                    }
2839                    return Ok(());
2840                }
2841            }
2842            Err(anyhow!("no such contact request"))?
2843        }
2844
        /// Not implemented by this fake; calling it panics via `unimplemented!()`.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }
2853
        /// Not implemented by this fake; calling it panics via `unimplemented!()`.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }
2857
        /// Not implemented by this fake; calling it panics via `unimplemented!()`.
        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }
2861
2862        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2863            self.background.simulate_random_delay().await;
2864            let mut orgs = self.orgs.lock();
2865            if orgs.values().any(|org| org.slug == slug) {
2866                Err(anyhow!("org already exists"))?
2867            } else {
2868                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2869                orgs.insert(
2870                    org_id,
2871                    Org {
2872                        id: org_id,
2873                        name: name.to_string(),
2874                        slug: slug.to_string(),
2875                    },
2876                );
2877                Ok(org_id)
2878            }
2879        }
2880
2881        async fn add_org_member(
2882            &self,
2883            org_id: OrgId,
2884            user_id: UserId,
2885            is_admin: bool,
2886        ) -> Result<()> {
2887            self.background.simulate_random_delay().await;
2888            if !self.orgs.lock().contains_key(&org_id) {
2889                Err(anyhow!("org does not exist"))?;
2890            }
2891            if !self.users.lock().contains_key(&user_id) {
2892                Err(anyhow!("user does not exist"))?;
2893            }
2894
2895            self.org_memberships
2896                .lock()
2897                .entry((org_id, user_id))
2898                .or_insert(is_admin);
2899            Ok(())
2900        }
2901
2902        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2903            self.background.simulate_random_delay().await;
2904            if !self.orgs.lock().contains_key(&org_id) {
2905                Err(anyhow!("org does not exist"))?;
2906            }
2907
2908            let mut channels = self.channels.lock();
2909            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2910            channels.insert(
2911                channel_id,
2912                Channel {
2913                    id: channel_id,
2914                    name: name.to_string(),
2915                    owner_id: org_id.0,
2916                    owner_is_user: false,
2917                },
2918            );
2919            Ok(channel_id)
2920        }
2921
2922        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2923            self.background.simulate_random_delay().await;
2924            Ok(self
2925                .channels
2926                .lock()
2927                .values()
2928                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2929                .cloned()
2930                .collect())
2931        }
2932
2933        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2934            self.background.simulate_random_delay().await;
2935            let channels = self.channels.lock();
2936            let memberships = self.channel_memberships.lock();
2937            Ok(channels
2938                .values()
2939                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2940                .cloned()
2941                .collect())
2942        }
2943
2944        async fn can_user_access_channel(
2945            &self,
2946            user_id: UserId,
2947            channel_id: ChannelId,
2948        ) -> Result<bool> {
2949            self.background.simulate_random_delay().await;
2950            Ok(self
2951                .channel_memberships
2952                .lock()
2953                .contains_key(&(channel_id, user_id)))
2954        }
2955
2956        async fn add_channel_member(
2957            &self,
2958            channel_id: ChannelId,
2959            user_id: UserId,
2960            is_admin: bool,
2961        ) -> Result<()> {
2962            self.background.simulate_random_delay().await;
2963            if !self.channels.lock().contains_key(&channel_id) {
2964                Err(anyhow!("channel does not exist"))?;
2965            }
2966            if !self.users.lock().contains_key(&user_id) {
2967                Err(anyhow!("user does not exist"))?;
2968            }
2969
2970            self.channel_memberships
2971                .lock()
2972                .entry((channel_id, user_id))
2973                .or_insert(is_admin);
2974            Ok(())
2975        }
2976
2977        async fn create_channel_message(
2978            &self,
2979            channel_id: ChannelId,
2980            sender_id: UserId,
2981            body: &str,
2982            timestamp: OffsetDateTime,
2983            nonce: u128,
2984        ) -> Result<MessageId> {
2985            self.background.simulate_random_delay().await;
2986            if !self.channels.lock().contains_key(&channel_id) {
2987                Err(anyhow!("channel does not exist"))?;
2988            }
2989            if !self.users.lock().contains_key(&sender_id) {
2990                Err(anyhow!("user does not exist"))?;
2991            }
2992
2993            let mut messages = self.channel_messages.lock();
2994            if let Some(message) = messages
2995                .values()
2996                .find(|message| message.nonce.as_u128() == nonce)
2997            {
2998                Ok(message.id)
2999            } else {
3000                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
3001                messages.insert(
3002                    message_id,
3003                    ChannelMessage {
3004                        id: message_id,
3005                        channel_id,
3006                        sender_id,
3007                        body: body.to_string(),
3008                        sent_at: timestamp,
3009                        nonce: Uuid::from_u128(nonce),
3010                    },
3011                );
3012                Ok(message_id)
3013            }
3014        }
3015
3016        async fn get_channel_messages(
3017            &self,
3018            channel_id: ChannelId,
3019            count: usize,
3020            before_id: Option<MessageId>,
3021        ) -> Result<Vec<ChannelMessage>> {
3022            self.background.simulate_random_delay().await;
3023            let mut messages = self
3024                .channel_messages
3025                .lock()
3026                .values()
3027                .rev()
3028                .filter(|message| {
3029                    message.channel_id == channel_id
3030                        && message.id < before_id.unwrap_or(MessageId::MAX)
3031                })
3032                .take(count)
3033                .cloned()
3034                .collect::<Vec<_>>();
3035            messages.sort_unstable_by_key(|message| message.id);
3036            Ok(messages)
3037        }
3038
        /// Does nothing: the body is empty and the string argument is ignored.
        async fn teardown(&self, _: &str) {}
3040
        /// Returns `Some(self)`, exposing this value as a concrete `FakeDb`.
        ///
        /// Only compiled under `cfg(test)`.
        #[cfg(test)]
        fn as_fake(&self) -> Option<&FakeDb> {
            Some(self)
        }
3045    }
3046
    /// Builds a background executor from a `Deterministic` instance created with `0`.
    ///
    /// NOTE(review): the `0` is presumably a fixed seed, so runs are repeatable;
    /// confirm against `Deterministic::new`.
    fn build_background_executor() -> Arc<Background> {
        Deterministic::new(0).build_background()
    }
3050}