db.rs

   1use std::{cmp, ops::Range, time::Duration};
   2
   3use crate::{Error, Result};
   4use anyhow::{anyhow, Context};
   5use async_trait::async_trait;
   6use axum::http::StatusCode;
   7use collections::HashMap;
   8use futures::StreamExt;
   9use serde::{Deserialize, Serialize};
  10pub use sqlx::postgres::PgPoolOptions as DbOptions;
  11use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
  12use time::{OffsetDateTime, PrimitiveDateTime};
  13
  14#[async_trait]
  15pub trait Db: Send + Sync {
  16    async fn create_user(
  17        &self,
  18        github_login: &str,
  19        email_address: Option<&str>,
  20        admin: bool,
  21    ) -> Result<UserId>;
  22    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
  23    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
  24    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
  25    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
  26    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
  27    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
  28    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
  29    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
  30    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
  31    async fn destroy_user(&self, id: UserId) -> Result<()>;
  32
  33    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
  34    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
  35    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
  36    async fn redeem_invite_code(
  37        &self,
  38        code: &str,
  39        login: &str,
  40        email_address: Option<&str>,
  41    ) -> Result<UserId>;
  42
  43    /// Registers a new project for the given user.
  44    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
  45
  46    /// Unregisters a project for the given project id.
  47    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
  48
  49    /// Update file counts by extension for the given project and worktree.
  50    async fn update_worktree_extensions(
  51        &self,
  52        project_id: ProjectId,
  53        worktree_id: u64,
  54        extensions: HashMap<String, u32>,
  55    ) -> Result<()>;
  56
  57    /// Get the file counts on the given project keyed by their worktree and extension.
  58    async fn get_project_extensions(
  59        &self,
  60        project_id: ProjectId,
  61    ) -> Result<HashMap<u64, HashMap<String, usize>>>;
  62
  63    /// Record which users have been active in which projects during
  64    /// a given period of time.
  65    async fn record_user_activity(
  66        &self,
  67        time_period: Range<OffsetDateTime>,
  68        active_projects: &[(UserId, ProjectId)],
  69    ) -> Result<()>;
  70
  71    /// Get the number of users who have been active in the given
  72    /// time period for at least the given time duration.
  73    async fn get_active_user_count(
  74        &self,
  75        time_period: Range<OffsetDateTime>,
  76        min_duration: Duration,
  77        only_collaborative: bool,
  78    ) -> Result<usize>;
  79
  80    /// Get the users that have been most active during the given time period,
  81    /// along with the amount of time they have been active in each project.
  82    async fn get_top_users_activity_summary(
  83        &self,
  84        time_period: Range<OffsetDateTime>,
  85        max_user_count: usize,
  86    ) -> Result<Vec<UserActivitySummary>>;
  87
  88    /// Get the project activity for the given user and time period.
  89    async fn get_user_activity_timeline(
  90        &self,
  91        time_period: Range<OffsetDateTime>,
  92        user_id: UserId,
  93    ) -> Result<Vec<UserActivityPeriod>>;
  94
  95    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
  96    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
  97    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  98    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  99    async fn dismiss_contact_notification(
 100        &self,
 101        responder_id: UserId,
 102        requester_id: UserId,
 103    ) -> Result<()>;
 104    async fn respond_to_contact_request(
 105        &self,
 106        responder_id: UserId,
 107        requester_id: UserId,
 108        accept: bool,
 109    ) -> Result<()>;
 110
 111    async fn create_access_token_hash(
 112        &self,
 113        user_id: UserId,
 114        access_token_hash: &str,
 115        max_access_token_count: usize,
 116    ) -> Result<()>;
 117    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
 118    #[cfg(any(test, feature = "seed-support"))]
 119
 120    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
 121    #[cfg(any(test, feature = "seed-support"))]
 122    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
 123    #[cfg(any(test, feature = "seed-support"))]
 124    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
 125    #[cfg(any(test, feature = "seed-support"))]
 126    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
 127    #[cfg(any(test, feature = "seed-support"))]
 128
 129    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
 130    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
 131    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
 132        -> Result<bool>;
 133    #[cfg(any(test, feature = "seed-support"))]
 134    async fn add_channel_member(
 135        &self,
 136        channel_id: ChannelId,
 137        user_id: UserId,
 138        is_admin: bool,
 139    ) -> Result<()>;
 140    async fn create_channel_message(
 141        &self,
 142        channel_id: ChannelId,
 143        sender_id: UserId,
 144        body: &str,
 145        timestamp: OffsetDateTime,
 146        nonce: u128,
 147    ) -> Result<MessageId>;
 148    async fn get_channel_messages(
 149        &self,
 150        channel_id: ChannelId,
 151        count: usize,
 152        before_id: Option<MessageId>,
 153    ) -> Result<Vec<ChannelMessage>>;
 154    #[cfg(test)]
 155    async fn teardown(&self, url: &str);
 156    #[cfg(test)]
 157    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
 158}
 159
 160pub struct PostgresDb {
 161    pool: sqlx::PgPool,
 162}
 163
 164impl PostgresDb {
 165    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 166        let pool = DbOptions::new()
 167            .max_connections(max_connections)
 168            .connect(&url)
 169            .await
 170            .context("failed to connect to postgres database")?;
 171        Ok(Self { pool })
 172    }
 173}
 174
 175#[async_trait]
 176impl Db for PostgresDb {
 177    // users
 178
 179    async fn create_user(
 180        &self,
 181        github_login: &str,
 182        email_address: Option<&str>,
 183        admin: bool,
 184    ) -> Result<UserId> {
 185        let query = "
 186            INSERT INTO users (github_login, email_address, admin)
 187            VALUES ($1, $2, $3)
 188            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 189            RETURNING id
 190        ";
 191        Ok(sqlx::query_scalar(query)
 192            .bind(github_login)
 193            .bind(email_address)
 194            .bind(admin)
 195            .fetch_one(&self.pool)
 196            .await
 197            .map(UserId)?)
 198    }
 199
 200    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 201        let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 202        Ok(sqlx::query_as(query)
 203            .bind(limit as i32)
 204            .bind((page * limit) as i32)
 205            .fetch_all(&self.pool)
 206            .await?)
 207    }
 208
 209    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
 210        let mut query = QueryBuilder::new(
 211            "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
 212        );
 213        query.push_values(
 214            users,
 215            |mut query, (github_login, email_address, invite_count)| {
 216                query
 217                    .push_bind(github_login)
 218                    .push_bind(email_address)
 219                    .push_bind(false)
 220                    .push_bind(random_invite_code())
 221                    .push_bind(invite_count as i32);
 222            },
 223        );
 224        query.push(
 225            "
 226            ON CONFLICT (github_login) DO UPDATE SET
 227                github_login = excluded.github_login,
 228                invite_count = excluded.invite_count,
 229                invite_code = CASE WHEN users.invite_code IS NULL
 230                                   THEN excluded.invite_code
 231                                   ELSE users.invite_code
 232                              END
 233            RETURNING id
 234            ",
 235        );
 236
 237        let rows = query.build().fetch_all(&self.pool).await?;
 238        Ok(rows
 239            .into_iter()
 240            .filter_map(|row| row.try_get::<UserId, _>(0).ok())
 241            .collect())
 242    }
 243
 244    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 245        let like_string = fuzzy_like_string(name_query);
 246        let query = "
 247            SELECT users.*
 248            FROM users
 249            WHERE github_login ILIKE $1
 250            ORDER BY github_login <-> $2
 251            LIMIT $3
 252        ";
 253        Ok(sqlx::query_as(query)
 254            .bind(like_string)
 255            .bind(name_query)
 256            .bind(limit as i32)
 257            .fetch_all(&self.pool)
 258            .await?)
 259    }
 260
 261    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 262        let users = self.get_users_by_ids(vec![id]).await?;
 263        Ok(users.into_iter().next())
 264    }
 265
 266    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 267        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 268        let query = "
 269            SELECT users.*
 270            FROM users
 271            WHERE users.id = ANY ($1)
 272        ";
 273        Ok(sqlx::query_as(query)
 274            .bind(&ids)
 275            .fetch_all(&self.pool)
 276            .await?)
 277    }
 278
 279    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
 280        let query = format!(
 281            "
 282            SELECT users.*
 283            FROM users
 284            WHERE invite_count = 0
 285            AND inviter_id IS{} NULL
 286            ",
 287            if invited_by_another_user { " NOT" } else { "" }
 288        );
 289
 290        Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 291    }
 292
 293    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 294        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 295        Ok(sqlx::query_as(query)
 296            .bind(github_login)
 297            .fetch_optional(&self.pool)
 298            .await?)
 299    }
 300
 301    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 302        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 303        Ok(sqlx::query(query)
 304            .bind(is_admin)
 305            .bind(id.0)
 306            .execute(&self.pool)
 307            .await
 308            .map(drop)?)
 309    }
 310
 311    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 312        let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 313        Ok(sqlx::query(query)
 314            .bind(connected_once)
 315            .bind(id.0)
 316            .execute(&self.pool)
 317            .await
 318            .map(drop)?)
 319    }
 320
 321    async fn destroy_user(&self, id: UserId) -> Result<()> {
 322        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 323        sqlx::query(query)
 324            .bind(id.0)
 325            .execute(&self.pool)
 326            .await
 327            .map(drop)?;
 328        let query = "DELETE FROM users WHERE id = $1;";
 329        Ok(sqlx::query(query)
 330            .bind(id.0)
 331            .execute(&self.pool)
 332            .await
 333            .map(drop)?)
 334    }
 335
 336    // invite codes
 337
 338    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
 339        let mut tx = self.pool.begin().await?;
 340        if count > 0 {
 341            sqlx::query(
 342                "
 343                UPDATE users
 344                SET invite_code = $1
 345                WHERE id = $2 AND invite_code IS NULL
 346            ",
 347            )
 348            .bind(random_invite_code())
 349            .bind(id)
 350            .execute(&mut tx)
 351            .await?;
 352        }
 353
 354        sqlx::query(
 355            "
 356            UPDATE users
 357            SET invite_count = $1
 358            WHERE id = $2
 359            ",
 360        )
 361        .bind(count as i32)
 362        .bind(id)
 363        .execute(&mut tx)
 364        .await?;
 365        tx.commit().await?;
 366        Ok(())
 367    }
 368
 369    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 370        let result: Option<(String, i32)> = sqlx::query_as(
 371            "
 372                SELECT invite_code, invite_count
 373                FROM users
 374                WHERE id = $1 AND invite_code IS NOT NULL 
 375            ",
 376        )
 377        .bind(id)
 378        .fetch_optional(&self.pool)
 379        .await?;
 380        if let Some((code, count)) = result {
 381            Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 382        } else {
 383            Ok(None)
 384        }
 385    }
 386
 387    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 388        sqlx::query_as(
 389            "
 390                SELECT *
 391                FROM users
 392                WHERE invite_code = $1
 393            ",
 394        )
 395        .bind(code)
 396        .fetch_optional(&self.pool)
 397        .await?
 398        .ok_or_else(|| {
 399            Error::Http(
 400                StatusCode::NOT_FOUND,
 401                "that invite code does not exist".to_string(),
 402            )
 403        })
 404    }
 405
 406    async fn redeem_invite_code(
 407        &self,
 408        code: &str,
 409        login: &str,
 410        email_address: Option<&str>,
 411    ) -> Result<UserId> {
 412        let mut tx = self.pool.begin().await?;
 413
 414        let inviter_id: Option<UserId> = sqlx::query_scalar(
 415            "
 416                UPDATE users
 417                SET invite_count = invite_count - 1
 418                WHERE
 419                    invite_code = $1 AND
 420                    invite_count > 0
 421                RETURNING id
 422            ",
 423        )
 424        .bind(code)
 425        .fetch_optional(&mut tx)
 426        .await?;
 427
 428        let inviter_id = match inviter_id {
 429            Some(inviter_id) => inviter_id,
 430            None => {
 431                if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
 432                    .bind(code)
 433                    .fetch_optional(&mut tx)
 434                    .await?
 435                    .is_some()
 436                {
 437                    Err(Error::Http(
 438                        StatusCode::UNAUTHORIZED,
 439                        "no invites remaining".to_string(),
 440                    ))?
 441                } else {
 442                    Err(Error::Http(
 443                        StatusCode::NOT_FOUND,
 444                        "invite code not found".to_string(),
 445                    ))?
 446                }
 447            }
 448        };
 449
 450        let invitee_id = sqlx::query_scalar(
 451            "
 452                INSERT INTO users
 453                    (github_login, email_address, admin, inviter_id, invite_code, invite_count)
 454                VALUES
 455                    ($1, $2, 'f', $3, $4, $5)
 456                RETURNING id
 457            ",
 458        )
 459        .bind(login)
 460        .bind(email_address)
 461        .bind(inviter_id)
 462        .bind(random_invite_code())
 463        .bind(5)
 464        .fetch_one(&mut tx)
 465        .await
 466        .map(UserId)?;
 467
 468        sqlx::query(
 469            "
 470                INSERT INTO contacts
 471                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 472                VALUES
 473                    ($1, $2, 't', 't', 't')
 474            ",
 475        )
 476        .bind(inviter_id)
 477        .bind(invitee_id)
 478        .execute(&mut tx)
 479        .await?;
 480
 481        tx.commit().await?;
 482        Ok(invitee_id)
 483    }
 484
 485    // projects
 486
 487    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
 488        Ok(sqlx::query_scalar(
 489            "
 490            INSERT INTO projects(host_user_id)
 491            VALUES ($1)
 492            RETURNING id
 493            ",
 494        )
 495        .bind(host_user_id)
 496        .fetch_one(&self.pool)
 497        .await
 498        .map(ProjectId)?)
 499    }
 500
 501    async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
 502        sqlx::query(
 503            "
 504            UPDATE projects
 505            SET unregistered = 't'
 506            WHERE id = $1
 507            ",
 508        )
 509        .bind(project_id)
 510        .execute(&self.pool)
 511        .await?;
 512        Ok(())
 513    }
 514
 515    async fn update_worktree_extensions(
 516        &self,
 517        project_id: ProjectId,
 518        worktree_id: u64,
 519        extensions: HashMap<String, u32>,
 520    ) -> Result<()> {
 521        if extensions.is_empty() {
 522            return Ok(());
 523        }
 524
 525        let mut query = QueryBuilder::new(
 526            "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
 527        );
 528        query.push_values(extensions, |mut query, (extension, count)| {
 529            query
 530                .push_bind(project_id)
 531                .push_bind(worktree_id as i32)
 532                .push_bind(extension)
 533                .push_bind(count as i32);
 534        });
 535        query.push(
 536            "
 537            ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
 538            count = excluded.count
 539            ",
 540        );
 541        query.build().execute(&self.pool).await?;
 542
 543        Ok(())
 544    }
 545
 546    async fn get_project_extensions(
 547        &self,
 548        project_id: ProjectId,
 549    ) -> Result<HashMap<u64, HashMap<String, usize>>> {
 550        #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
 551        struct WorktreeExtension {
 552            worktree_id: i32,
 553            extension: String,
 554            count: i32,
 555        }
 556
 557        let query = "
 558            SELECT worktree_id, extension, count
 559            FROM worktree_extensions
 560            WHERE project_id = $1
 561        ";
 562        let counts = sqlx::query_as::<_, WorktreeExtension>(query)
 563            .bind(&project_id)
 564            .fetch_all(&self.pool)
 565            .await?;
 566
 567        let mut extension_counts = HashMap::default();
 568        for count in counts {
 569            extension_counts
 570                .entry(count.worktree_id as u64)
 571                .or_insert(HashMap::default())
 572                .insert(count.extension, count.count as usize);
 573        }
 574        Ok(extension_counts)
 575    }
 576
 577    async fn record_user_activity(
 578        &self,
 579        time_period: Range<OffsetDateTime>,
 580        projects: &[(UserId, ProjectId)],
 581    ) -> Result<()> {
 582        let query = "
 583            INSERT INTO project_activity_periods
 584            (ended_at, duration_millis, user_id, project_id)
 585            VALUES
 586            ($1, $2, $3, $4);
 587        ";
 588
 589        let mut tx = self.pool.begin().await?;
 590        let duration_millis =
 591            ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
 592        for (user_id, project_id) in projects {
 593            sqlx::query(query)
 594                .bind(time_period.end)
 595                .bind(duration_millis)
 596                .bind(user_id)
 597                .bind(project_id)
 598                .execute(&mut tx)
 599                .await?;
 600        }
 601        tx.commit().await?;
 602
 603        Ok(())
 604    }
 605
 606    async fn get_active_user_count(
 607        &self,
 608        time_period: Range<OffsetDateTime>,
 609        min_duration: Duration,
 610        only_collaborative: bool,
 611    ) -> Result<usize> {
 612        let mut with_clause = String::new();
 613        with_clause.push_str("WITH\n");
 614        with_clause.push_str(
 615            "
 616            project_durations AS (
 617                SELECT user_id, project_id, SUM(duration_millis) AS project_duration
 618                FROM project_activity_periods
 619                WHERE $1 < ended_at AND ended_at <= $2
 620                GROUP BY user_id, project_id
 621            ),
 622            ",
 623        );
 624        with_clause.push_str(
 625            "
 626            project_collaborators as (
 627                SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
 628                FROM project_durations
 629                GROUP BY project_id
 630            ),
 631            ",
 632        );
 633
 634        if only_collaborative {
 635            with_clause.push_str(
 636                "
 637                user_durations AS (
 638                    SELECT user_id, SUM(project_duration) as total_duration
 639                    FROM project_durations, project_collaborators
 640                    WHERE
 641                        project_durations.project_id = project_collaborators.project_id AND
 642                        max_collaborators > 1
 643                    GROUP BY user_id
 644                    ORDER BY total_duration DESC
 645                    LIMIT $3
 646                )
 647                ",
 648            );
 649        } else {
 650            with_clause.push_str(
 651                "
 652                user_durations AS (
 653                    SELECT user_id, SUM(project_duration) as total_duration
 654                    FROM project_durations
 655                    GROUP BY user_id
 656                    ORDER BY total_duration DESC
 657                    LIMIT $3
 658                )
 659                ",
 660            );
 661        }
 662
 663        let query = format!(
 664            "
 665            {with_clause}
 666            SELECT count(user_durations.user_id)
 667            FROM user_durations
 668            WHERE user_durations.total_duration >= $3
 669            "
 670        );
 671
 672        let count: i64 = sqlx::query_scalar(&query)
 673            .bind(time_period.start)
 674            .bind(time_period.end)
 675            .bind(min_duration.as_millis() as i64)
 676            .fetch_one(&self.pool)
 677            .await?;
 678        Ok(count as usize)
 679    }
 680
 681    async fn get_top_users_activity_summary(
 682        &self,
 683        time_period: Range<OffsetDateTime>,
 684        max_user_count: usize,
 685    ) -> Result<Vec<UserActivitySummary>> {
 686        let query = "
 687            WITH
 688                project_durations AS (
 689                    SELECT user_id, project_id, SUM(duration_millis) AS project_duration
 690                    FROM project_activity_periods
 691                    WHERE $1 < ended_at AND ended_at <= $2
 692                    GROUP BY user_id, project_id
 693                ),
 694                user_durations AS (
 695                    SELECT user_id, SUM(project_duration) as total_duration
 696                    FROM project_durations
 697                    GROUP BY user_id
 698                    ORDER BY total_duration DESC
 699                    LIMIT $3
 700                ),
 701                project_collaborators as (
 702                    SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
 703                    FROM project_durations
 704                    GROUP BY project_id
 705                )
 706            SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
 707            FROM user_durations, project_durations, project_collaborators, users
 708            WHERE
 709                user_durations.user_id = project_durations.user_id AND
 710                user_durations.user_id = users.id AND
 711                project_durations.project_id = project_collaborators.project_id
 712            ORDER BY total_duration DESC, user_id ASC
 713        ";
 714
 715        let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
 716            .bind(time_period.start)
 717            .bind(time_period.end)
 718            .bind(max_user_count as i32)
 719            .fetch(&self.pool);
 720
 721        let mut result = Vec::<UserActivitySummary>::new();
 722        while let Some(row) = rows.next().await {
 723            let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?;
 724            let project_id = project_id;
 725            let duration = Duration::from_millis(duration_millis as u64);
 726            let project_activity = ProjectActivitySummary {
 727                id: project_id,
 728                duration,
 729                max_collaborators: project_collaborators as usize,
 730            };
 731            if let Some(last_summary) = result.last_mut() {
 732                if last_summary.id == user_id {
 733                    last_summary.project_activity.push(project_activity);
 734                    continue;
 735                }
 736            }
 737            result.push(UserActivitySummary {
 738                id: user_id,
 739                project_activity: vec![project_activity],
 740                github_login,
 741            });
 742        }
 743
 744        Ok(result)
 745    }
 746
 747    async fn get_user_activity_timeline(
 748        &self,
 749        time_period: Range<OffsetDateTime>,
 750        user_id: UserId,
 751    ) -> Result<Vec<UserActivityPeriod>> {
 752        const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
 753
 754        let query = "
 755            SELECT
 756                project_activity_periods.ended_at,
 757                project_activity_periods.duration_millis,
 758                project_activity_periods.project_id,
 759                worktree_extensions.extension,
 760                worktree_extensions.count
 761            FROM project_activity_periods
 762            LEFT OUTER JOIN
 763                worktree_extensions
 764            ON
 765                project_activity_periods.project_id = worktree_extensions.project_id
 766            WHERE
 767                project_activity_periods.user_id = $1 AND
 768                $2 < project_activity_periods.ended_at AND
 769                project_activity_periods.ended_at <= $3
 770            ORDER BY project_activity_periods.id ASC
 771        ";
 772
 773        let mut rows = sqlx::query_as::<
 774            _,
 775            (
 776                PrimitiveDateTime,
 777                i32,
 778                ProjectId,
 779                Option<String>,
 780                Option<i32>,
 781            ),
 782        >(query)
 783        .bind(user_id)
 784        .bind(time_period.start)
 785        .bind(time_period.end)
 786        .fetch(&self.pool);
 787
 788        let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
 789        while let Some(row) = rows.next().await {
 790            let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
 791            let ended_at = ended_at.assume_utc();
 792            let duration = Duration::from_millis(duration_millis as u64);
 793            let started_at = ended_at - duration;
 794            let project_time_periods = time_periods.entry(project_id).or_default();
 795
 796            if let Some(prev_duration) = project_time_periods.last_mut() {
 797                if started_at <= prev_duration.end + COALESCE_THRESHOLD
 798                    && ended_at >= prev_duration.start
 799                {
 800                    prev_duration.end = cmp::max(prev_duration.end, ended_at);
 801                } else {
 802                    project_time_periods.push(UserActivityPeriod {
 803                        project_id,
 804                        start: started_at,
 805                        end: ended_at,
 806                        extensions: Default::default(),
 807                    });
 808                }
 809            } else {
 810                project_time_periods.push(UserActivityPeriod {
 811                    project_id,
 812                    start: started_at,
 813                    end: ended_at,
 814                    extensions: Default::default(),
 815                });
 816            }
 817
 818            if let Some((extension, extension_count)) = extension.zip(extension_count) {
 819                project_time_periods
 820                    .last_mut()
 821                    .unwrap()
 822                    .extensions
 823                    .insert(extension, extension_count as usize);
 824            }
 825        }
 826
 827        let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
 828        durations.sort_unstable_by_key(|duration| duration.start);
 829        Ok(durations)
 830    }
 831
 832    // contacts
 833
 834    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 835        let query = "
 836            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 837            FROM contacts
 838            WHERE user_id_a = $1 OR user_id_b = $1;
 839        ";
 840
 841        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 842            .bind(user_id)
 843            .fetch(&self.pool);
 844
 845        let mut contacts = vec![Contact::Accepted {
 846            user_id,
 847            should_notify: false,
 848        }];
 849        while let Some(row) = rows.next().await {
 850            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 851
 852            if user_id_a == user_id {
 853                if accepted {
 854                    contacts.push(Contact::Accepted {
 855                        user_id: user_id_b,
 856                        should_notify: should_notify && a_to_b,
 857                    });
 858                } else if a_to_b {
 859                    contacts.push(Contact::Outgoing { user_id: user_id_b })
 860                } else {
 861                    contacts.push(Contact::Incoming {
 862                        user_id: user_id_b,
 863                        should_notify,
 864                    });
 865                }
 866            } else {
 867                if accepted {
 868                    contacts.push(Contact::Accepted {
 869                        user_id: user_id_a,
 870                        should_notify: should_notify && !a_to_b,
 871                    });
 872                } else if a_to_b {
 873                    contacts.push(Contact::Incoming {
 874                        user_id: user_id_a,
 875                        should_notify,
 876                    });
 877                } else {
 878                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 879                }
 880            }
 881        }
 882
 883        contacts.sort_unstable_by_key(|contact| contact.user_id());
 884
 885        Ok(contacts)
 886    }
 887
 888    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 889        let (id_a, id_b) = if user_id_1 < user_id_2 {
 890            (user_id_1, user_id_2)
 891        } else {
 892            (user_id_2, user_id_1)
 893        };
 894
 895        let query = "
 896            SELECT 1 FROM contacts
 897            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
 898            LIMIT 1
 899        ";
 900        Ok(sqlx::query_scalar::<_, i32>(query)
 901            .bind(id_a.0)
 902            .bind(id_b.0)
 903            .fetch_optional(&self.pool)
 904            .await?
 905            .is_some())
 906    }
 907
 908    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 909        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 910            (sender_id, receiver_id, true)
 911        } else {
 912            (receiver_id, sender_id, false)
 913        };
 914        let query = "
 915            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 916            VALUES ($1, $2, $3, 'f', 't')
 917            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 918            SET
 919                accepted = 't',
 920                should_notify = 'f'
 921            WHERE
 922                NOT contacts.accepted AND
 923                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 924                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 925        ";
 926        let result = sqlx::query(query)
 927            .bind(id_a.0)
 928            .bind(id_b.0)
 929            .bind(a_to_b)
 930            .execute(&self.pool)
 931            .await?;
 932
 933        if result.rows_affected() == 1 {
 934            Ok(())
 935        } else {
 936            Err(anyhow!("contact already requested"))?
 937        }
 938    }
 939
 940    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 941        let (id_a, id_b) = if responder_id < requester_id {
 942            (responder_id, requester_id)
 943        } else {
 944            (requester_id, responder_id)
 945        };
 946        let query = "
 947            DELETE FROM contacts
 948            WHERE user_id_a = $1 AND user_id_b = $2;
 949        ";
 950        let result = sqlx::query(query)
 951            .bind(id_a.0)
 952            .bind(id_b.0)
 953            .execute(&self.pool)
 954            .await?;
 955
 956        if result.rows_affected() == 1 {
 957            Ok(())
 958        } else {
 959            Err(anyhow!("no such contact"))?
 960        }
 961    }
 962
 963    async fn dismiss_contact_notification(
 964        &self,
 965        user_id: UserId,
 966        contact_user_id: UserId,
 967    ) -> Result<()> {
 968        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
 969            (user_id, contact_user_id, true)
 970        } else {
 971            (contact_user_id, user_id, false)
 972        };
 973
 974        let query = "
 975            UPDATE contacts
 976            SET should_notify = 'f'
 977            WHERE
 978                user_id_a = $1 AND user_id_b = $2 AND
 979                (
 980                    (a_to_b = $3 AND accepted) OR
 981                    (a_to_b != $3 AND NOT accepted)
 982                );
 983        ";
 984
 985        let result = sqlx::query(query)
 986            .bind(id_a.0)
 987            .bind(id_b.0)
 988            .bind(a_to_b)
 989            .execute(&self.pool)
 990            .await?;
 991
 992        if result.rows_affected() == 0 {
 993            Err(anyhow!("no such contact request"))?;
 994        }
 995
 996        Ok(())
 997    }
 998
 999    async fn respond_to_contact_request(
1000        &self,
1001        responder_id: UserId,
1002        requester_id: UserId,
1003        accept: bool,
1004    ) -> Result<()> {
1005        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1006            (responder_id, requester_id, false)
1007        } else {
1008            (requester_id, responder_id, true)
1009        };
1010        let result = if accept {
1011            let query = "
1012                UPDATE contacts
1013                SET accepted = 't', should_notify = 't'
1014                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1015            ";
1016            sqlx::query(query)
1017                .bind(id_a.0)
1018                .bind(id_b.0)
1019                .bind(a_to_b)
1020                .execute(&self.pool)
1021                .await?
1022        } else {
1023            let query = "
1024                DELETE FROM contacts
1025                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1026            ";
1027            sqlx::query(query)
1028                .bind(id_a.0)
1029                .bind(id_b.0)
1030                .bind(a_to_b)
1031                .execute(&self.pool)
1032                .await?
1033        };
1034        if result.rows_affected() == 1 {
1035            Ok(())
1036        } else {
1037            Err(anyhow!("no such contact request"))?
1038        }
1039    }
1040
1041    // access tokens
1042
1043    async fn create_access_token_hash(
1044        &self,
1045        user_id: UserId,
1046        access_token_hash: &str,
1047        max_access_token_count: usize,
1048    ) -> Result<()> {
1049        let insert_query = "
1050            INSERT INTO access_tokens (user_id, hash)
1051            VALUES ($1, $2);
1052        ";
1053        let cleanup_query = "
1054            DELETE FROM access_tokens
1055            WHERE id IN (
1056                SELECT id from access_tokens
1057                WHERE user_id = $1
1058                ORDER BY id DESC
1059                OFFSET $3
1060            )
1061        ";
1062
1063        let mut tx = self.pool.begin().await?;
1064        sqlx::query(insert_query)
1065            .bind(user_id.0)
1066            .bind(access_token_hash)
1067            .execute(&mut tx)
1068            .await?;
1069        sqlx::query(cleanup_query)
1070            .bind(user_id.0)
1071            .bind(access_token_hash)
1072            .bind(max_access_token_count as i32)
1073            .execute(&mut tx)
1074            .await?;
1075        Ok(tx.commit().await?)
1076    }
1077
1078    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1079        let query = "
1080            SELECT hash
1081            FROM access_tokens
1082            WHERE user_id = $1
1083            ORDER BY id DESC
1084        ";
1085        Ok(sqlx::query_scalar(query)
1086            .bind(user_id.0)
1087            .fetch_all(&self.pool)
1088            .await?)
1089    }
1090
1091    // orgs
1092
1093    #[allow(unused)] // Help rust-analyzer
1094    #[cfg(any(test, feature = "seed-support"))]
1095    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1096        let query = "
1097            SELECT *
1098            FROM orgs
1099            WHERE slug = $1
1100        ";
1101        Ok(sqlx::query_as(query)
1102            .bind(slug)
1103            .fetch_optional(&self.pool)
1104            .await?)
1105    }
1106
1107    #[cfg(any(test, feature = "seed-support"))]
1108    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1109        let query = "
1110            INSERT INTO orgs (name, slug)
1111            VALUES ($1, $2)
1112            RETURNING id
1113        ";
1114        Ok(sqlx::query_scalar(query)
1115            .bind(name)
1116            .bind(slug)
1117            .fetch_one(&self.pool)
1118            .await
1119            .map(OrgId)?)
1120    }
1121
1122    #[cfg(any(test, feature = "seed-support"))]
1123    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1124        let query = "
1125            INSERT INTO org_memberships (org_id, user_id, admin)
1126            VALUES ($1, $2, $3)
1127            ON CONFLICT DO NOTHING
1128        ";
1129        Ok(sqlx::query(query)
1130            .bind(org_id.0)
1131            .bind(user_id.0)
1132            .bind(is_admin)
1133            .execute(&self.pool)
1134            .await
1135            .map(drop)?)
1136    }
1137
1138    // channels
1139
1140    #[cfg(any(test, feature = "seed-support"))]
1141    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1142        let query = "
1143            INSERT INTO channels (owner_id, owner_is_user, name)
1144            VALUES ($1, false, $2)
1145            RETURNING id
1146        ";
1147        Ok(sqlx::query_scalar(query)
1148            .bind(org_id.0)
1149            .bind(name)
1150            .fetch_one(&self.pool)
1151            .await
1152            .map(ChannelId)?)
1153    }
1154
1155    #[allow(unused)] // Help rust-analyzer
1156    #[cfg(any(test, feature = "seed-support"))]
1157    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1158        let query = "
1159            SELECT *
1160            FROM channels
1161            WHERE
1162                channels.owner_is_user = false AND
1163                channels.owner_id = $1
1164        ";
1165        Ok(sqlx::query_as(query)
1166            .bind(org_id.0)
1167            .fetch_all(&self.pool)
1168            .await?)
1169    }
1170
1171    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1172        let query = "
1173            SELECT
1174                channels.*
1175            FROM
1176                channel_memberships, channels
1177            WHERE
1178                channel_memberships.user_id = $1 AND
1179                channel_memberships.channel_id = channels.id
1180        ";
1181        Ok(sqlx::query_as(query)
1182            .bind(user_id.0)
1183            .fetch_all(&self.pool)
1184            .await?)
1185    }
1186
1187    async fn can_user_access_channel(
1188        &self,
1189        user_id: UserId,
1190        channel_id: ChannelId,
1191    ) -> Result<bool> {
1192        let query = "
1193            SELECT id
1194            FROM channel_memberships
1195            WHERE user_id = $1 AND channel_id = $2
1196            LIMIT 1
1197        ";
1198        Ok(sqlx::query_scalar::<_, i32>(query)
1199            .bind(user_id.0)
1200            .bind(channel_id.0)
1201            .fetch_optional(&self.pool)
1202            .await
1203            .map(|e| e.is_some())?)
1204    }
1205
1206    #[cfg(any(test, feature = "seed-support"))]
1207    async fn add_channel_member(
1208        &self,
1209        channel_id: ChannelId,
1210        user_id: UserId,
1211        is_admin: bool,
1212    ) -> Result<()> {
1213        let query = "
1214            INSERT INTO channel_memberships (channel_id, user_id, admin)
1215            VALUES ($1, $2, $3)
1216            ON CONFLICT DO NOTHING
1217        ";
1218        Ok(sqlx::query(query)
1219            .bind(channel_id.0)
1220            .bind(user_id.0)
1221            .bind(is_admin)
1222            .execute(&self.pool)
1223            .await
1224            .map(drop)?)
1225    }
1226
1227    // messages
1228
1229    async fn create_channel_message(
1230        &self,
1231        channel_id: ChannelId,
1232        sender_id: UserId,
1233        body: &str,
1234        timestamp: OffsetDateTime,
1235        nonce: u128,
1236    ) -> Result<MessageId> {
1237        let query = "
1238            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1239            VALUES ($1, $2, $3, $4, $5)
1240            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1241            RETURNING id
1242        ";
1243        Ok(sqlx::query_scalar(query)
1244            .bind(channel_id.0)
1245            .bind(sender_id.0)
1246            .bind(body)
1247            .bind(timestamp)
1248            .bind(Uuid::from_u128(nonce))
1249            .fetch_one(&self.pool)
1250            .await
1251            .map(MessageId)?)
1252    }
1253
1254    async fn get_channel_messages(
1255        &self,
1256        channel_id: ChannelId,
1257        count: usize,
1258        before_id: Option<MessageId>,
1259    ) -> Result<Vec<ChannelMessage>> {
1260        let query = r#"
1261            SELECT * FROM (
1262                SELECT
1263                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1264                FROM
1265                    channel_messages
1266                WHERE
1267                    channel_id = $1 AND
1268                    id < $2
1269                ORDER BY id DESC
1270                LIMIT $3
1271            ) as recent_messages
1272            ORDER BY id ASC
1273        "#;
1274        Ok(sqlx::query_as(query)
1275            .bind(channel_id.0)
1276            .bind(before_id.unwrap_or(MessageId::MAX))
1277            .bind(count as i64)
1278            .fetch_all(&self.pool)
1279            .await?)
1280    }
1281
1282    #[cfg(test)]
1283    async fn teardown(&self, url: &str) {
1284        use util::ResultExt;
1285
1286        let query = "
1287            SELECT pg_terminate_backend(pg_stat_activity.pid)
1288            FROM pg_stat_activity
1289            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1290        ";
1291        sqlx::query(query).execute(&self.pool).await.log_err();
1292        self.pool.close().await;
1293        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1294            .await
1295            .log_err();
1296    }
1297
1298    #[cfg(test)]
1299    fn as_fake(&self) -> Option<&tests::FakeDb> {
1300        None
1301    }
1302}
1303
/// Declares a database id newtype around `i32`.
///
/// The generated type:
/// - derives ordering, hashing, serde and sqlx support (transparent over the inner `i32`);
/// - has `MAX`, `from_proto` and `to_proto` helpers;
/// - implements `Display` by forwarding to the inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id (`i32::MAX`).
            #[allow(unused)] // not every id type uses every helper
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from its wire form.
            /// NOTE: narrowed with `as`, so values above `i32::MAX` wrap around.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $name(value as i32)
            }

            /// Converts the id to its wire form (sign-extending `as` cast).
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Forward (rather than `write!`) so width/fill flags still apply.
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1346
// `UserId`: an `i32` id newtype generated by `id_type!`.
id_type!(UserId);
/// A user record as loaded from the database via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub email_address: Option<String>,
    /// Admin flag (see `set_user_is_admin`).
    pub admin: bool,
    /// The user's invite code, if one has been assigned (see `get_invite_code_for_user`).
    pub invite_code: Option<String>,
    /// Invite count stored for the user (see `set_invite_count` / `create_users`).
    pub invite_count: i32,
    /// Flag maintained by `set_user_connected_once`.
    pub connected_once: bool,
}
1358
// `ProjectId`: an `i32` id newtype generated by `id_type!`.
id_type!(ProjectId);
/// A project record as loaded from the database via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// The user who registered/hosts the project.
    pub host_user_id: UserId,
    /// NOTE(review): presumably set once the project has been unregistered — confirm.
    pub unregistered: bool,
}
1366
/// Per-user activity summary, as returned by `get_top_users_activity_summary`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// Breakdown of the user's activity per project.
    pub project_activity: Vec<ProjectActivitySummary>,
}
1373
/// One project's entry within a [`UserActivitySummary`].
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ProjectActivitySummary {
    id: ProjectId,
    /// Time the user spent active in this project within the queried range.
    duration: Duration,
    /// NOTE(review): appears to be the peak number of simultaneously active users in the
    /// project over the range — confirm against the summary query.
    max_collaborators: usize,
}
1380
/// A contiguous span of a user's activity in one project, as returned by
/// `get_user_activity_timeline`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    project_id: ProjectId,
    // Timestamps are serialized as ISO 8601 strings.
    #[serde(with = "time::serde::iso8601")]
    start: OffsetDateTime,
    #[serde(with = "time::serde::iso8601")]
    end: OffsetDateTime,
    /// File-extension counts for the project's worktrees (extension -> count).
    extensions: HashMap<String, usize>,
}
1390
// `OrgId`: an `i32` id newtype generated by `id_type!`.
id_type!(OrgId);
/// An organization record as loaded from the database via `FromRow`.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Lookup key used by `find_org_by_slug`.
    pub slug: String,
}
1398
// `ChannelId`: an `i32` id newtype generated by `id_type!`.
id_type!(ChannelId);
/// A chat channel record as loaded from the database via `FromRow`.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Id of the owner; interpret via `owner_is_user`.
    pub owner_id: i32,
    /// `true` if `owner_id` is a user id, `false` if it is an org id (as written by
    /// `create_org_channel`).
    pub owner_is_user: bool,
}
1407
// `MessageId`: an `i32` id newtype generated by `id_type!`.
id_type!(MessageId);
/// A chat message as loaded by `get_channel_messages`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Selected as `sent_at AT TIME ZONE 'UTC'` by `get_channel_messages`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied nonce (a `u128` stored as a UUID) used to deduplicate resends.
    pub nonce: Uuid,
}
1418
/// A contact relationship seen from one user's side (see `get_contacts`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// Both users are contacts.
    Accepted {
        user_id: UserId,
        /// Set only for the user who sent the original request, while the acceptance
        /// notification has not been dismissed.
        should_notify: bool,
    },
    /// The viewing user sent a request that is still pending.
    Outgoing {
        user_id: UserId,
    },
    /// The viewing user received a request that is still pending.
    Incoming {
        user_id: UserId,
        /// Whether the request notification is still pending.
        should_notify: bool,
    },
}
1433
1434impl Contact {
1435    pub fn user_id(&self) -> UserId {
1436        match self {
1437            Contact::Accepted { user_id, .. } => *user_id,
1438            Contact::Outgoing { user_id } => *user_id,
1439            Contact::Incoming { user_id, .. } => *user_id,
1440        }
1441    }
1442}
1443
/// A pending contact request addressed to some user.
/// NOTE(review): not referenced in this part of the file — confirm it is still used.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1449
/// Builds a SQL `LIKE` pattern that matches any string containing the alphanumeric characters
/// of `string` in order, with anything in between.
///
/// Example: `"a-b"` becomes `"%a%b%"`. Non-alphanumeric input characters, including the `LIKE`
/// wildcards `%` and `_`, are dropped.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    pattern.extend(
        string
            .chars()
            .filter(|ch| ch.is_alphanumeric())
            .flat_map(|ch| ['%', ch]),
    );
    pattern.push('%');
    pattern
}
1461
/// Generates a random 16-character invite code using `nanoid`'s default alphabet.
fn random_invite_code() -> String {
    nanoid::nanoid!(16)
}
1465
1466#[cfg(test)]
1467pub mod tests {
1468    use super::*;
1469    use anyhow::anyhow;
1470    use collections::BTreeMap;
1471    use gpui::executor::{Background, Deterministic};
1472    use lazy_static::lazy_static;
1473    use parking_lot::Mutex;
1474    use rand::prelude::*;
1475    use sqlx::{
1476        migrate::{MigrateDatabase, Migrator},
1477        Postgres,
1478    };
1479    use std::{path::Path, sync::Arc};
1480    use util::post_inc;
1481
1482    #[tokio::test(flavor = "multi_thread")]
1483    async fn test_get_users_by_ids() {
1484        for test_db in [
1485            TestDb::postgres().await,
1486            TestDb::fake(build_background_executor()),
1487        ] {
1488            let db = test_db.db();
1489
1490            let user = db.create_user("user", None, false).await.unwrap();
1491            let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1492            let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1493            let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1494
1495            assert_eq!(
1496                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1497                    .await
1498                    .unwrap(),
1499                vec![
1500                    User {
1501                        id: user,
1502                        github_login: "user".to_string(),
1503                        admin: false,
1504                        ..Default::default()
1505                    },
1506                    User {
1507                        id: friend1,
1508                        github_login: "friend-1".to_string(),
1509                        admin: false,
1510                        ..Default::default()
1511                    },
1512                    User {
1513                        id: friend2,
1514                        github_login: "friend-2".to_string(),
1515                        admin: false,
1516                        ..Default::default()
1517                    },
1518                    User {
1519                        id: friend3,
1520                        github_login: "friend-3".to_string(),
1521                        admin: false,
1522                        ..Default::default()
1523                    }
1524                ]
1525            );
1526        }
1527    }
1528
1529    #[tokio::test(flavor = "multi_thread")]
1530    async fn test_create_users() {
1531        let db = TestDb::postgres().await;
1532        let db = db.db();
1533
1534        // Create the first batch of users, ensuring invite counts are assigned
1535        // correctly and the respective invite codes are unique.
1536        let user_ids_batch_1 = db
1537            .create_users(vec![
1538                ("user1".to_string(), "hi@user1.com".to_string(), 5),
1539                ("user2".to_string(), "hi@user2.com".to_string(), 4),
1540                ("user3".to_string(), "hi@user3.com".to_string(), 3),
1541            ])
1542            .await
1543            .unwrap();
1544        assert_eq!(user_ids_batch_1.len(), 3);
1545
1546        let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
1547        assert_eq!(users.len(), 3);
1548        assert_eq!(users[0].github_login, "user1");
1549        assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
1550        assert_eq!(users[0].invite_count, 5);
1551        assert_eq!(users[1].github_login, "user2");
1552        assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
1553        assert_eq!(users[1].invite_count, 4);
1554        assert_eq!(users[2].github_login, "user3");
1555        assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
1556        assert_eq!(users[2].invite_count, 3);
1557
1558        let invite_code_1 = users[0].invite_code.clone().unwrap();
1559        let invite_code_2 = users[1].invite_code.clone().unwrap();
1560        let invite_code_3 = users[2].invite_code.clone().unwrap();
1561        assert_ne!(invite_code_1, invite_code_2);
1562        assert_ne!(invite_code_1, invite_code_3);
1563        assert_ne!(invite_code_2, invite_code_3);
1564
1565        // Create the second batch of users and include a user that is already in the database, ensuring
1566        // the invite count for the existing user is updated without changing their invite code.
1567        let user_ids_batch_2 = db
1568            .create_users(vec![
1569                ("user2".to_string(), "hi@user2.com".to_string(), 10),
1570                ("user4".to_string(), "hi@user4.com".to_string(), 2),
1571            ])
1572            .await
1573            .unwrap();
1574        assert_eq!(user_ids_batch_2.len(), 2);
1575        assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);
1576
1577        let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
1578        assert_eq!(users.len(), 2);
1579        assert_eq!(users[0].github_login, "user2");
1580        assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
1581        assert_eq!(users[0].invite_count, 10);
1582        assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
1583        assert_eq!(users[1].github_login, "user4");
1584        assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
1585        assert_eq!(users[1].invite_count, 2);
1586
1587        let invite_code_4 = users[1].invite_code.clone().unwrap();
1588        assert_ne!(invite_code_4, invite_code_1);
1589        assert_ne!(invite_code_4, invite_code_2);
1590        assert_ne!(invite_code_4, invite_code_3);
1591    }
1592
1593    #[tokio::test(flavor = "multi_thread")]
1594    async fn test_worktree_extensions() {
1595        let test_db = TestDb::postgres().await;
1596        let db = test_db.db();
1597
1598        let user = db.create_user("user_1", None, false).await.unwrap();
1599        let project = db.register_project(user).await.unwrap();
1600
1601        db.update_worktree_extensions(project, 100, Default::default())
1602            .await
1603            .unwrap();
1604        db.update_worktree_extensions(
1605            project,
1606            100,
1607            [("rs".to_string(), 5), ("md".to_string(), 3)]
1608                .into_iter()
1609                .collect(),
1610        )
1611        .await
1612        .unwrap();
1613        db.update_worktree_extensions(
1614            project,
1615            100,
1616            [("rs".to_string(), 6), ("md".to_string(), 5)]
1617                .into_iter()
1618                .collect(),
1619        )
1620        .await
1621        .unwrap();
1622        db.update_worktree_extensions(
1623            project,
1624            101,
1625            [("ts".to_string(), 2), ("md".to_string(), 1)]
1626                .into_iter()
1627                .collect(),
1628        )
1629        .await
1630        .unwrap();
1631
1632        assert_eq!(
1633            db.get_project_extensions(project).await.unwrap(),
1634            [
1635                (
1636                    100,
1637                    [("rs".into(), 6), ("md".into(), 5),]
1638                        .into_iter()
1639                        .collect::<HashMap<_, _>>()
1640                ),
1641                (
1642                    101,
1643                    [("ts".into(), 2), ("md".into(), 1),]
1644                        .into_iter()
1645                        .collect::<HashMap<_, _>>()
1646                )
1647            ]
1648            .into_iter()
1649            .collect()
1650        );
1651    }
1652
1653    #[tokio::test(flavor = "multi_thread")]
1654    async fn test_user_activity() {
1655        let test_db = TestDb::postgres().await;
1656        let db = test_db.db();
1657
1658        let user_1 = db.create_user("user_1", None, false).await.unwrap();
1659        let user_2 = db.create_user("user_2", None, false).await.unwrap();
1660        let user_3 = db.create_user("user_3", None, false).await.unwrap();
1661        let project_1 = db.register_project(user_1).await.unwrap();
1662        db.update_worktree_extensions(
1663            project_1,
1664            1,
1665            HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]),
1666        )
1667        .await
1668        .unwrap();
1669        let project_2 = db.register_project(user_2).await.unwrap();
1670        let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);
1671
1672        // User 2 opens a project
1673        let t1 = t0 + Duration::from_secs(10);
1674        db.record_user_activity(t0..t1, &[(user_2, project_2)])
1675            .await
1676            .unwrap();
1677
1678        let t2 = t1 + Duration::from_secs(10);
1679        db.record_user_activity(t1..t2, &[(user_2, project_2)])
1680            .await
1681            .unwrap();
1682
1683        // User 1 joins the project
1684        let t3 = t2 + Duration::from_secs(10);
1685        db.record_user_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
1686            .await
1687            .unwrap();
1688
1689        // User 1 opens another project
1690        let t4 = t3 + Duration::from_secs(10);
1691        db.record_user_activity(
1692            t3..t4,
1693            &[
1694                (user_2, project_2),
1695                (user_1, project_2),
1696                (user_1, project_1),
1697            ],
1698        )
1699        .await
1700        .unwrap();
1701
1702        // User 3 joins that project
1703        let t5 = t4 + Duration::from_secs(10);
1704        db.record_user_activity(
1705            t4..t5,
1706            &[
1707                (user_2, project_2),
1708                (user_1, project_2),
1709                (user_1, project_1),
1710                (user_3, project_1),
1711            ],
1712        )
1713        .await
1714        .unwrap();
1715
1716        // User 2 leaves
1717        let t6 = t5 + Duration::from_secs(5);
1718        db.record_user_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
1719            .await
1720            .unwrap();
1721
1722        let t7 = t6 + Duration::from_secs(60);
1723        let t8 = t7 + Duration::from_secs(10);
1724        db.record_user_activity(t7..t8, &[(user_1, project_1)])
1725            .await
1726            .unwrap();
1727
1728        assert_eq!(
1729            db.get_top_users_activity_summary(t0..t6, 10).await.unwrap(),
1730            &[
1731                UserActivitySummary {
1732                    id: user_1,
1733                    github_login: "user_1".to_string(),
1734                    project_activity: vec![
1735                        ProjectActivitySummary {
1736                            id: project_1,
1737                            duration: Duration::from_secs(25),
1738                            max_collaborators: 2
1739                        },
1740                        ProjectActivitySummary {
1741                            id: project_2,
1742                            duration: Duration::from_secs(30),
1743                            max_collaborators: 2
1744                        }
1745                    ]
1746                },
1747                UserActivitySummary {
1748                    id: user_2,
1749                    github_login: "user_2".to_string(),
1750                    project_activity: vec![ProjectActivitySummary {
1751                        id: project_2,
1752                        duration: Duration::from_secs(50),
1753                        max_collaborators: 2
1754                    }]
1755                },
1756                UserActivitySummary {
1757                    id: user_3,
1758                    github_login: "user_3".to_string(),
1759                    project_activity: vec![ProjectActivitySummary {
1760                        id: project_1,
1761                        duration: Duration::from_secs(15),
1762                        max_collaborators: 2
1763                    }]
1764                },
1765            ]
1766        );
1767
1768        assert_eq!(
1769            db.get_active_user_count(t0..t6, Duration::from_secs(56), false)
1770                .await
1771                .unwrap(),
1772            0
1773        );
1774        assert_eq!(
1775            db.get_active_user_count(t0..t6, Duration::from_secs(56), true)
1776                .await
1777                .unwrap(),
1778            0
1779        );
1780        assert_eq!(
1781            db.get_active_user_count(t0..t6, Duration::from_secs(54), false)
1782                .await
1783                .unwrap(),
1784            1
1785        );
1786        assert_eq!(
1787            db.get_active_user_count(t0..t6, Duration::from_secs(54), true)
1788                .await
1789                .unwrap(),
1790            1
1791        );
1792        assert_eq!(
1793            db.get_active_user_count(t0..t6, Duration::from_secs(30), false)
1794                .await
1795                .unwrap(),
1796            2
1797        );
1798        assert_eq!(
1799            db.get_active_user_count(t0..t6, Duration::from_secs(30), true)
1800                .await
1801                .unwrap(),
1802            2
1803        );
1804        assert_eq!(
1805            db.get_active_user_count(t0..t6, Duration::from_secs(10), false)
1806                .await
1807                .unwrap(),
1808            3
1809        );
1810        assert_eq!(
1811            db.get_active_user_count(t0..t6, Duration::from_secs(10), true)
1812                .await
1813                .unwrap(),
1814            3
1815        );
1816        assert_eq!(
1817            db.get_active_user_count(t0..t1, Duration::from_secs(5), false)
1818                .await
1819                .unwrap(),
1820            1
1821        );
1822        assert_eq!(
1823            db.get_active_user_count(t0..t1, Duration::from_secs(5), true)
1824                .await
1825                .unwrap(),
1826            0
1827        );
1828
1829        assert_eq!(
1830            db.get_user_activity_timeline(t3..t6, user_1).await.unwrap(),
1831            &[
1832                UserActivityPeriod {
1833                    project_id: project_1,
1834                    start: t3,
1835                    end: t6,
1836                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1837                },
1838                UserActivityPeriod {
1839                    project_id: project_2,
1840                    start: t3,
1841                    end: t5,
1842                    extensions: Default::default(),
1843                },
1844            ]
1845        );
1846        assert_eq!(
1847            db.get_user_activity_timeline(t0..t8, user_1).await.unwrap(),
1848            &[
1849                UserActivityPeriod {
1850                    project_id: project_2,
1851                    start: t2,
1852                    end: t5,
1853                    extensions: Default::default(),
1854                },
1855                UserActivityPeriod {
1856                    project_id: project_1,
1857                    start: t3,
1858                    end: t6,
1859                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1860                },
1861                UserActivityPeriod {
1862                    project_id: project_1,
1863                    start: t7,
1864                    end: t8,
1865                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1866                },
1867            ]
1868        );
1869    }
1870
1871    #[tokio::test(flavor = "multi_thread")]
1872    async fn test_recent_channel_messages() {
1873        for test_db in [
1874            TestDb::postgres().await,
1875            TestDb::fake(build_background_executor()),
1876        ] {
1877            let db = test_db.db();
1878            let user = db.create_user("user", None, false).await.unwrap();
1879            let org = db.create_org("org", "org").await.unwrap();
1880            let channel = db.create_org_channel(org, "channel").await.unwrap();
1881            for i in 0..10 {
1882                db.create_channel_message(
1883                    channel,
1884                    user,
1885                    &i.to_string(),
1886                    OffsetDateTime::now_utc(),
1887                    i,
1888                )
1889                .await
1890                .unwrap();
1891            }
1892
1893            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1894            assert_eq!(
1895                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1896                ["5", "6", "7", "8", "9"]
1897            );
1898
1899            let prev_messages = db
1900                .get_channel_messages(channel, 4, Some(messages[0].id))
1901                .await
1902                .unwrap();
1903            assert_eq!(
1904                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1905                ["1", "2", "3", "4"]
1906            );
1907        }
1908    }
1909
1910    #[tokio::test(flavor = "multi_thread")]
1911    async fn test_channel_message_nonces() {
1912        for test_db in [
1913            TestDb::postgres().await,
1914            TestDb::fake(build_background_executor()),
1915        ] {
1916            let db = test_db.db();
1917            let user = db.create_user("user", None, false).await.unwrap();
1918            let org = db.create_org("org", "org").await.unwrap();
1919            let channel = db.create_org_channel(org, "channel").await.unwrap();
1920
1921            let msg1_id = db
1922                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1923                .await
1924                .unwrap();
1925            let msg2_id = db
1926                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1927                .await
1928                .unwrap();
1929            let msg3_id = db
1930                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1931                .await
1932                .unwrap();
1933            let msg4_id = db
1934                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1935                .await
1936                .unwrap();
1937
1938            assert_ne!(msg1_id, msg2_id);
1939            assert_eq!(msg1_id, msg3_id);
1940            assert_eq!(msg2_id, msg4_id);
1941        }
1942    }
1943
1944    #[tokio::test(flavor = "multi_thread")]
1945    async fn test_create_access_tokens() {
1946        let test_db = TestDb::postgres().await;
1947        let db = test_db.db();
1948        let user = db.create_user("the-user", None, false).await.unwrap();
1949
1950        db.create_access_token_hash(user, "h1", 3).await.unwrap();
1951        db.create_access_token_hash(user, "h2", 3).await.unwrap();
1952        assert_eq!(
1953            db.get_access_token_hashes(user).await.unwrap(),
1954            &["h2".to_string(), "h1".to_string()]
1955        );
1956
1957        db.create_access_token_hash(user, "h3", 3).await.unwrap();
1958        assert_eq!(
1959            db.get_access_token_hashes(user).await.unwrap(),
1960            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1961        );
1962
1963        db.create_access_token_hash(user, "h4", 3).await.unwrap();
1964        assert_eq!(
1965            db.get_access_token_hashes(user).await.unwrap(),
1966            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1967        );
1968
1969        db.create_access_token_hash(user, "h5", 3).await.unwrap();
1970        assert_eq!(
1971            db.get_access_token_hashes(user).await.unwrap(),
1972            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1973        );
1974    }
1975
1976    #[test]
1977    fn test_fuzzy_like_string() {
1978        assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1979        assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1980        assert_eq!(fuzzy_like_string(" z  "), "%z%");
1981    }
1982
1983    #[tokio::test(flavor = "multi_thread")]
1984    async fn test_fuzzy_search_users() {
1985        let test_db = TestDb::postgres().await;
1986        let db = test_db.db();
1987        for github_login in [
1988            "California",
1989            "colorado",
1990            "oregon",
1991            "washington",
1992            "florida",
1993            "delaware",
1994            "rhode-island",
1995        ] {
1996            db.create_user(github_login, None, false).await.unwrap();
1997        }
1998
1999        assert_eq!(
2000            fuzzy_search_user_names(db, "clr").await,
2001            &["colorado", "California"]
2002        );
2003        assert_eq!(
2004            fuzzy_search_user_names(db, "ro").await,
2005            &["rhode-island", "colorado", "oregon"],
2006        );
2007
2008        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
2009            db.fuzzy_search_users(query, 10)
2010                .await
2011                .unwrap()
2012                .into_iter()
2013                .map(|user| user.github_login)
2014                .collect::<Vec<_>>()
2015        }
2016    }
2017
    /// Walks through the contact lifecycle against both the Postgres and the fake backends.
    /// It covers sending a request, dismissing its notification, accepting, rejecting
    /// duplicate requests, dismissing acceptance notifications, mutual requests, and
    /// declining. Each step builds on the state left by the previous one, so statement order
    /// matters.
    ///
    /// Every `get_contacts(user)` result asserted below also lists `user` itself as an
    /// accepted contact with `should_notify: false`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(build_background_executor()),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", None, false).await.unwrap();
            let user_2 = db.create_user("user2", None, false).await.unwrap();
            let user_3 = db.create_user("user3", None, false).await.unwrap();

            // User starts with no contacts other than themselves
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                vec![Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );

            // User requests a contact. Both users see the pending request.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert!(!db.has_contact(user_1, user_2).await.unwrap());
            assert!(!db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Outgoing { user_id: user_2 }
                ],
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: true
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again.
            // Dismissing from the requester's side (user 1) fails.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true
                    }
                ],
            );
            assert!(db.has_contact(user_1, user_2).await.unwrap());
            assert!(db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users cannot re-request existing contacts.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users can't dismiss notifications of them accepting other users' requests.
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap_err();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true,
                    },
                ]
            );

            // Users can dismiss notifications of other users accepting their requests.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users send each other concurrent contact requests and
            // see that they are immediately accepted.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    },
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert!(!db.has_contact(user_2, user_3).await.unwrap());
            assert!(!db.has_contact(user_3, user_2).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    }
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );
        }
    }
2237
    /// Exercises invite codes end to end on Postgres. `FakeDb` leaves `set_invite_count`
    /// and `redeem_invite_code` unimplemented, so this test does not run against it.
    ///
    /// Covers code assignment via `set_invite_count` and redemption, which decrements the
    /// count and makes the two users accepted contacts. It also checks rejection once the
    /// count hits zero, topping the count up again (the code itself is kept), rejection of
    /// logins that already exist, and that newly invited users receive their own code with
    /// a count of 5.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", None, false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

        // Setting invite count to 0 when no code is assigned does not assign a new code
        db.set_invite_count(user1, 0).await.unwrap();
        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        let user2 = db
            .redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db
            .redeem_invite_code(&invite_code, "user-3", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db
            .redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes.
        // The failed attempt must not consume a use of the code.
        db.redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap_err();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);

        // Ensure invited users get invite codes too (with an initial count of 5).
        assert_eq!(
            db.get_invite_code_for_user(user2).await.unwrap().unwrap().1,
            5
        );
        assert_eq!(
            db.get_invite_code_for_user(user3).await.unwrap().unwrap().1,
            5
        );
        assert_eq!(
            db.get_invite_code_for_user(user4).await.unwrap().unwrap().1,
            5
        );
    }
2404
    /// A database handle for tests, backed either by a freshly created Postgres database
    /// (`TestDb::postgres`) or by an in-memory `FakeDb` (`TestDb::fake`). The database is
    /// torn down when the `TestDb` is dropped.
    pub struct TestDb {
        /// The wrapped database. It is an `Option` so that `Drop` can `take()` it and run
        /// its teardown.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL passed to `teardown`; empty for the fake backend.
        pub url: String,
    }
2409
    impl TestDb {
        /// Creates a fresh, uniquely named Postgres database on `localhost`, runs the
        /// migrations found in `$CARGO_MANIFEST_DIR/migrations` against it, and wraps it.
        ///
        /// # Panics
        ///
        /// Panics if the database cannot be created, connected to, or migrated.
        pub async fn postgres() -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Serialize database setup across all tests in this process. The guard lives
            // until the end of this function, i.e. across every `.await` below.
            // NOTE(review): `lock()` is synchronous here, so a test waiting for the lock blocks
            // its executor thread instead of yielding. Confirm this is acceptable.
            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            // Random suffix so databases created by different tests don't collide.
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            Postgres::create_database(&url)
                .await
                .expect("failed to create test db");
            // Presumably `5` is the connection-pool size. Confirm against `PostgresDb::new`.
            let db = PostgresDb::new(&url, 5).await.unwrap();
            let migrator = Migrator::new(migrations_path).await.unwrap();
            migrator.run(&db.pool).await.unwrap();
            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Wraps a new in-memory `FakeDb` whose random delays are driven by `background`.
        /// The `url` is left empty.
        pub fn fake(background: Arc<Background>) -> Self {
            Self {
                db: Some(Arc::new(FakeDb::new(background))),
                url: Default::default(),
            }
        }

        /// Returns the wrapped database.
        ///
        /// # Panics
        ///
        /// Panics if `db` is `None`, e.g. after `Drop` has taken it.
        pub fn db(&self) -> &Arc<dyn Db> {
            self.db.as_ref().unwrap()
        }
    }
2444
2445    impl Drop for TestDb {
2446        fn drop(&mut self) {
2447            if let Some(db) = self.db.take() {
2448                futures::executor::block_on(db.teardown(&self.url));
2449            }
2450        }
2451    }
2452
    /// In-memory [`Db`] implementation for tests. Each table lives behind its own `Mutex`,
    /// and the implemented (non-stub) trait methods await
    /// `background.simulate_random_delay()` before touching that state.
    pub struct FakeDb {
        /// Executor used to inject random delays into database calls.
        background: Arc<Background>,
        /// Users keyed by id.
        pub users: Mutex<BTreeMap<UserId, User>>,
        /// Projects keyed by id. Unregistering a project only sets its `unregistered` flag.
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        /// File-extension counts keyed by `(project, worktree id, extension)`.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
        /// Orgs keyed by id.
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        /// Memberships keyed by `(org, user)`. The `bool` is presumably an admin flag (confirm).
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        /// Channels keyed by id.
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        /// Memberships keyed by `(channel, user)`. The `bool` is presumably an admin flag (confirm).
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        /// Channel messages keyed by id.
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        /// Contact relationships between users; see [`FakeContact`].
        pub contacts: Mutex<Vec<FakeContact>>,
        // Counters for the next id to hand out. `create_user` and `register_project` consume
        // theirs with `post_inc`. `FakeDb::new` starts the user counter at 0 and the others at 1.
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
2470
    /// One contact relationship between two users as stored by [`FakeDb`].
    #[derive(Debug)]
    pub struct FakeContact {
        /// Presumably the user who sent the contact request.
        pub requester_id: UserId,
        /// Presumably the user the request was sent to.
        pub responder_id: UserId,
        /// Whether the request has been accepted.
        pub accepted: bool,
        /// Whether a notification for this relationship is still pending (presumably
        /// cleared by `dismiss_contact_notification`; confirm).
        pub should_notify: bool,
    }
2478
2479    impl FakeDb {
2480        pub fn new(background: Arc<Background>) -> Self {
2481            Self {
2482                background,
2483                users: Default::default(),
2484                next_user_id: Mutex::new(0),
2485                projects: Default::default(),
2486                worktree_extensions: Default::default(),
2487                next_project_id: Mutex::new(1),
2488                orgs: Default::default(),
2489                next_org_id: Mutex::new(1),
2490                org_memberships: Default::default(),
2491                channels: Default::default(),
2492                next_channel_id: Mutex::new(1),
2493                channel_memberships: Default::default(),
2494                channel_messages: Default::default(),
2495                next_channel_message_id: Mutex::new(1),
2496                contacts: Default::default(),
2497            }
2498        }
2499    }
2500
2501    #[async_trait]
2502    impl Db for FakeDb {
2503        async fn create_user(
2504            &self,
2505            github_login: &str,
2506            email_address: Option<&str>,
2507            admin: bool,
2508        ) -> Result<UserId> {
2509            self.background.simulate_random_delay().await;
2510
2511            let mut users = self.users.lock();
2512            if let Some(user) = users
2513                .values()
2514                .find(|user| user.github_login == github_login)
2515            {
2516                Ok(user.id)
2517            } else {
2518                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2519                users.insert(
2520                    user_id,
2521                    User {
2522                        id: user_id,
2523                        github_login: github_login.to_string(),
2524                        email_address: email_address.map(str::to_string),
2525                        admin,
2526                        invite_code: None,
2527                        invite_count: 0,
2528                        connected_once: false,
2529                    },
2530                );
2531                Ok(user_id)
2532            }
2533        }
2534
        /// Not supported by the fake; panics via `unimplemented!()`.
        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }

        /// Not supported by the fake; panics via `unimplemented!()`.
        async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
            unimplemented!()
        }

        /// Not supported by the fake; panics via `unimplemented!()`.
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2546
2547        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2548            self.background.simulate_random_delay().await;
2549            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2550        }
2551
2552        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2553            self.background.simulate_random_delay().await;
2554            let users = self.users.lock();
2555            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2556        }
2557
        /// Not supported by the fake; panics via `unimplemented!()`.
        async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
            unimplemented!()
        }
2561
2562        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2563            self.background.simulate_random_delay().await;
2564            Ok(self
2565                .users
2566                .lock()
2567                .values()
2568                .find(|user| user.github_login == github_login)
2569                .cloned())
2570        }
2571
        /// Not supported by the fake; panics via `unimplemented!()`.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
2575
2576        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2577            self.background.simulate_random_delay().await;
2578            let mut users = self.users.lock();
2579            let mut user = users
2580                .get_mut(&id)
2581                .ok_or_else(|| anyhow!("user not found"))?;
2582            user.connected_once = connected_once;
2583            Ok(())
2584        }
2585
        /// Not supported by the fake; panics via `unimplemented!()`.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }
2589
2590        // invite codes
2591
        /// Not supported by the fake; panics via `unimplemented!()`.
        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }
2595
2596        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
2597            self.background.simulate_random_delay().await;
2598            Ok(None)
2599        }
2600
        /// Not supported by the fake; panics via `unimplemented!()`.
        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }

        /// Not supported by the fake; panics via `unimplemented!()`.
        async fn redeem_invite_code(
            &self,
            _code: &str,
            _login: &str,
            _email_address: Option<&str>,
        ) -> Result<UserId> {
            unimplemented!()
        }
2613
2614        // projects
2615
2616        async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2617            self.background.simulate_random_delay().await;
2618            if !self.users.lock().contains_key(&host_user_id) {
2619                Err(anyhow!("no such user"))?;
2620            }
2621
2622            let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2623            self.projects.lock().insert(
2624                project_id,
2625                Project {
2626                    id: project_id,
2627                    host_user_id,
2628                    unregistered: false,
2629                },
2630            );
2631            Ok(project_id)
2632        }
2633
2634        async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2635            self.background.simulate_random_delay().await;
2636            self.projects
2637                .lock()
2638                .get_mut(&project_id)
2639                .ok_or_else(|| anyhow!("no such project"))?
2640                .unregistered = true;
2641            Ok(())
2642        }
2643
2644        async fn update_worktree_extensions(
2645            &self,
2646            project_id: ProjectId,
2647            worktree_id: u64,
2648            extensions: HashMap<String, u32>,
2649        ) -> Result<()> {
2650            self.background.simulate_random_delay().await;
2651            if !self.projects.lock().contains_key(&project_id) {
2652                Err(anyhow!("no such project"))?;
2653            }
2654
2655            for (extension, count) in extensions {
2656                self.worktree_extensions
2657                    .lock()
2658                    .insert((project_id, worktree_id, extension), count);
2659            }
2660
2661            Ok(())
2662        }
2663
        /// Not supported by the fake database: panics via `unimplemented!()` if called.
        async fn get_project_extensions(
            &self,
            _project_id: ProjectId,
        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
            unimplemented!()
        }
2670
        /// Not supported by the fake database: panics via `unimplemented!()` if called.
        async fn record_user_activity(
            &self,
            _time_period: Range<OffsetDateTime>,
            _active_projects: &[(UserId, ProjectId)],
        ) -> Result<()> {
            unimplemented!()
        }
2678
        /// Not supported by the fake database: panics via `unimplemented!()` if called.
        async fn get_active_user_count(
            &self,
            _time_period: Range<OffsetDateTime>,
            _min_duration: Duration,
            _only_collaborative: bool,
        ) -> Result<usize> {
            unimplemented!()
        }
2687
        /// Not supported by the fake database: panics via `unimplemented!()` if called.
        async fn get_top_users_activity_summary(
            &self,
            _time_period: Range<OffsetDateTime>,
            _limit: usize,
        ) -> Result<Vec<UserActivitySummary>> {
            unimplemented!()
        }
2695
        /// Not supported by the fake database: panics via `unimplemented!()` if called.
        async fn get_user_activity_timeline(
            &self,
            _time_period: Range<OffsetDateTime>,
            _user_id: UserId,
        ) -> Result<Vec<UserActivityPeriod>> {
            unimplemented!()
        }
2703
2704        // contacts
2705
2706        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2707            self.background.simulate_random_delay().await;
2708            let mut contacts = vec![Contact::Accepted {
2709                user_id: id,
2710                should_notify: false,
2711            }];
2712
2713            for contact in self.contacts.lock().iter() {
2714                if contact.requester_id == id {
2715                    if contact.accepted {
2716                        contacts.push(Contact::Accepted {
2717                            user_id: contact.responder_id,
2718                            should_notify: contact.should_notify,
2719                        });
2720                    } else {
2721                        contacts.push(Contact::Outgoing {
2722                            user_id: contact.responder_id,
2723                        });
2724                    }
2725                } else if contact.responder_id == id {
2726                    if contact.accepted {
2727                        contacts.push(Contact::Accepted {
2728                            user_id: contact.requester_id,
2729                            should_notify: false,
2730                        });
2731                    } else {
2732                        contacts.push(Contact::Incoming {
2733                            user_id: contact.requester_id,
2734                            should_notify: contact.should_notify,
2735                        });
2736                    }
2737                }
2738            }
2739
2740            contacts.sort_unstable_by_key(|contact| contact.user_id());
2741            Ok(contacts)
2742        }
2743
2744        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2745            self.background.simulate_random_delay().await;
2746            Ok(self.contacts.lock().iter().any(|contact| {
2747                contact.accepted
2748                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2749                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2750            }))
2751        }
2752
2753        async fn send_contact_request(
2754            &self,
2755            requester_id: UserId,
2756            responder_id: UserId,
2757        ) -> Result<()> {
2758            self.background.simulate_random_delay().await;
2759            let mut contacts = self.contacts.lock();
2760            for contact in contacts.iter_mut() {
2761                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2762                    if contact.accepted {
2763                        Err(anyhow!("contact already exists"))?;
2764                    } else {
2765                        Err(anyhow!("contact already requested"))?;
2766                    }
2767                }
2768                if contact.responder_id == requester_id && contact.requester_id == responder_id {
2769                    if contact.accepted {
2770                        Err(anyhow!("contact already exists"))?;
2771                    } else {
2772                        contact.accepted = true;
2773                        contact.should_notify = false;
2774                        return Ok(());
2775                    }
2776                }
2777            }
2778            contacts.push(FakeContact {
2779                requester_id,
2780                responder_id,
2781                accepted: false,
2782                should_notify: true,
2783            });
2784            Ok(())
2785        }
2786
2787        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2788            self.background.simulate_random_delay().await;
2789            self.contacts.lock().retain(|contact| {
2790                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2791            });
2792            Ok(())
2793        }
2794
2795        async fn dismiss_contact_notification(
2796            &self,
2797            user_id: UserId,
2798            contact_user_id: UserId,
2799        ) -> Result<()> {
2800            self.background.simulate_random_delay().await;
2801            let mut contacts = self.contacts.lock();
2802            for contact in contacts.iter_mut() {
2803                if contact.requester_id == contact_user_id
2804                    && contact.responder_id == user_id
2805                    && !contact.accepted
2806                {
2807                    contact.should_notify = false;
2808                    return Ok(());
2809                }
2810                if contact.requester_id == user_id
2811                    && contact.responder_id == contact_user_id
2812                    && contact.accepted
2813                {
2814                    contact.should_notify = false;
2815                    return Ok(());
2816                }
2817            }
2818            Err(anyhow!("no such notification"))?
2819        }
2820
2821        async fn respond_to_contact_request(
2822            &self,
2823            responder_id: UserId,
2824            requester_id: UserId,
2825            accept: bool,
2826        ) -> Result<()> {
2827            self.background.simulate_random_delay().await;
2828            let mut contacts = self.contacts.lock();
2829            for (ix, contact) in contacts.iter_mut().enumerate() {
2830                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2831                    if contact.accepted {
2832                        Err(anyhow!("contact already confirmed"))?;
2833                    }
2834                    if accept {
2835                        contact.accepted = true;
2836                        contact.should_notify = true;
2837                    } else {
2838                        contacts.remove(ix);
2839                    }
2840                    return Ok(());
2841                }
2842            }
2843            Err(anyhow!("no such contact request"))?
2844        }
2845
        /// Not supported by the fake database: panics via `unimplemented!()` if called.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }
2854
        /// Not supported by the fake database: panics via `unimplemented!()` if called.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }
2858
        /// Not supported by the fake database: panics via `unimplemented!()` if called.
        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }
2862
2863        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2864            self.background.simulate_random_delay().await;
2865            let mut orgs = self.orgs.lock();
2866            if orgs.values().any(|org| org.slug == slug) {
2867                Err(anyhow!("org already exists"))?
2868            } else {
2869                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2870                orgs.insert(
2871                    org_id,
2872                    Org {
2873                        id: org_id,
2874                        name: name.to_string(),
2875                        slug: slug.to_string(),
2876                    },
2877                );
2878                Ok(org_id)
2879            }
2880        }
2881
2882        async fn add_org_member(
2883            &self,
2884            org_id: OrgId,
2885            user_id: UserId,
2886            is_admin: bool,
2887        ) -> Result<()> {
2888            self.background.simulate_random_delay().await;
2889            if !self.orgs.lock().contains_key(&org_id) {
2890                Err(anyhow!("org does not exist"))?;
2891            }
2892            if !self.users.lock().contains_key(&user_id) {
2893                Err(anyhow!("user does not exist"))?;
2894            }
2895
2896            self.org_memberships
2897                .lock()
2898                .entry((org_id, user_id))
2899                .or_insert(is_admin);
2900            Ok(())
2901        }
2902
2903        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2904            self.background.simulate_random_delay().await;
2905            if !self.orgs.lock().contains_key(&org_id) {
2906                Err(anyhow!("org does not exist"))?;
2907            }
2908
2909            let mut channels = self.channels.lock();
2910            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2911            channels.insert(
2912                channel_id,
2913                Channel {
2914                    id: channel_id,
2915                    name: name.to_string(),
2916                    owner_id: org_id.0,
2917                    owner_is_user: false,
2918                },
2919            );
2920            Ok(channel_id)
2921        }
2922
2923        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2924            self.background.simulate_random_delay().await;
2925            Ok(self
2926                .channels
2927                .lock()
2928                .values()
2929                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2930                .cloned()
2931                .collect())
2932        }
2933
2934        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2935            self.background.simulate_random_delay().await;
2936            let channels = self.channels.lock();
2937            let memberships = self.channel_memberships.lock();
2938            Ok(channels
2939                .values()
2940                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2941                .cloned()
2942                .collect())
2943        }
2944
2945        async fn can_user_access_channel(
2946            &self,
2947            user_id: UserId,
2948            channel_id: ChannelId,
2949        ) -> Result<bool> {
2950            self.background.simulate_random_delay().await;
2951            Ok(self
2952                .channel_memberships
2953                .lock()
2954                .contains_key(&(channel_id, user_id)))
2955        }
2956
2957        async fn add_channel_member(
2958            &self,
2959            channel_id: ChannelId,
2960            user_id: UserId,
2961            is_admin: bool,
2962        ) -> Result<()> {
2963            self.background.simulate_random_delay().await;
2964            if !self.channels.lock().contains_key(&channel_id) {
2965                Err(anyhow!("channel does not exist"))?;
2966            }
2967            if !self.users.lock().contains_key(&user_id) {
2968                Err(anyhow!("user does not exist"))?;
2969            }
2970
2971            self.channel_memberships
2972                .lock()
2973                .entry((channel_id, user_id))
2974                .or_insert(is_admin);
2975            Ok(())
2976        }
2977
2978        async fn create_channel_message(
2979            &self,
2980            channel_id: ChannelId,
2981            sender_id: UserId,
2982            body: &str,
2983            timestamp: OffsetDateTime,
2984            nonce: u128,
2985        ) -> Result<MessageId> {
2986            self.background.simulate_random_delay().await;
2987            if !self.channels.lock().contains_key(&channel_id) {
2988                Err(anyhow!("channel does not exist"))?;
2989            }
2990            if !self.users.lock().contains_key(&sender_id) {
2991                Err(anyhow!("user does not exist"))?;
2992            }
2993
2994            let mut messages = self.channel_messages.lock();
2995            if let Some(message) = messages
2996                .values()
2997                .find(|message| message.nonce.as_u128() == nonce)
2998            {
2999                Ok(message.id)
3000            } else {
3001                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
3002                messages.insert(
3003                    message_id,
3004                    ChannelMessage {
3005                        id: message_id,
3006                        channel_id,
3007                        sender_id,
3008                        body: body.to_string(),
3009                        sent_at: timestamp,
3010                        nonce: Uuid::from_u128(nonce),
3011                    },
3012                );
3013                Ok(message_id)
3014            }
3015        }
3016
3017        async fn get_channel_messages(
3018            &self,
3019            channel_id: ChannelId,
3020            count: usize,
3021            before_id: Option<MessageId>,
3022        ) -> Result<Vec<ChannelMessage>> {
3023            self.background.simulate_random_delay().await;
3024            let mut messages = self
3025                .channel_messages
3026                .lock()
3027                .values()
3028                .rev()
3029                .filter(|message| {
3030                    message.channel_id == channel_id
3031                        && message.id < before_id.unwrap_or(MessageId::MAX)
3032                })
3033                .take(count)
3034                .cloned()
3035                .collect::<Vec<_>>();
3036            messages.sort_unstable_by_key(|message| message.id);
3037            Ok(messages)
3038        }
3039
        /// No-op for the fake database; the string argument is ignored.
        async fn teardown(&self, _: &str) {}
3041
        /// Test-only accessor: the fake database always exposes itself, so this
        /// returns `Some(self)`.
        #[cfg(test)]
        fn as_fake(&self) -> Option<&FakeDb> {
            Some(self)
        }
3046    }
3047
    /// Builds a background executor from a freshly created `Deterministic`
    /// scheduler (`Deterministic::new(0)`; the `0` is presumably a fixed seed —
    /// confirm against `Deterministic::new`).
    fn build_background_executor() -> Arc<Background> {
        Deterministic::new(0).build_background()
    }
3051}