db.rs

   1use std::{cmp, ops::Range, time::Duration};
   2
   3use crate::{Error, Result};
   4use anyhow::{anyhow, Context};
   5use async_trait::async_trait;
   6use axum::http::StatusCode;
   7use collections::HashMap;
   8use futures::StreamExt;
   9use nanoid::nanoid;
  10use serde::{Deserialize, Serialize};
  11pub use sqlx::postgres::PgPoolOptions as DbOptions;
  12use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
  13use time::{OffsetDateTime, PrimitiveDateTime};
  14
/// Storage interface for users, invite codes, projects, user activity,
/// contacts, access tokens, orgs and channels.
///
/// Every method except `as_fake` is async, and every method except `teardown`
/// and `as_fake` returns the crate-level `Result`. Implementors must be
/// `Send + Sync`.
#[async_trait]
pub trait Db: Send + Sync {
    /// Creates a user with the given GitHub login, optional email address and
    /// admin flag, returning the user's id.
    async fn create_user(
        &self,
        github_login: &str,
        email_address: Option<&str>,
        admin: bool,
    ) -> Result<UserId>;
    /// Returns one page of users; `page` and `limit` select which page and how
    /// many users it holds.
    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
    /// Creates several users at once. Each tuple is
    /// `(github_login, email_address, invite_count)`.
    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
    /// Returns up to `limit` users matching `query` fuzzily.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Looks up a single user by id.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Looks up the users with the given ids.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns users that have no invites left. `invited_by_another_user`
    /// selects between users that have an inviter (`true`) and users that do
    /// not (`false`).
    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
    /// Looks up a single user by GitHub login.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets the admin flag of the given user.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Sets the `connected_once` flag of the given user.
    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
    /// Deletes the given user.
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    /// Sets the number of invites the given user may hand out.
    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
    /// Returns the user's invite code together with their invite count, if the
    /// user has an invite code.
    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
    /// Returns the user that owns the given invite code.
    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
    /// Redeems an invite code on behalf of a new user with the given login and
    /// optional email address, returning the new user's id.
    async fn redeem_invite_code(
        &self,
        code: &str,
        login: &str,
        email_address: Option<&str>,
    ) -> Result<UserId>;

    /// Registers a new project for the given user.
    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;

    /// Unregisters a project for the given project id.
    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;

    /// Update file counts by extension for the given project and worktree.
    async fn update_worktree_extensions(
        &self,
        project_id: ProjectId,
        worktree_id: u64,
        extensions: HashMap<String, u32>,
    ) -> Result<()>;

    /// Get the file counts on the given project keyed by their worktree and extension.
    async fn get_project_extensions(
        &self,
        project_id: ProjectId,
    ) -> Result<HashMap<u64, HashMap<String, usize>>>;

    /// Record which users have been active in which projects during
    /// a given period of time.
    async fn record_user_activity(
        &self,
        time_period: Range<OffsetDateTime>,
        active_projects: &[(UserId, ProjectId)],
    ) -> Result<()>;

    /// Get the number of users who have been active in the given
    /// time period for at least the given time duration.
    async fn get_active_user_count(
        &self,
        time_period: Range<OffsetDateTime>,
        min_duration: Duration,
    ) -> Result<usize>;

    /// Get the users that have been most active during the given time period,
    /// along with the amount of time they have been active in each project.
    async fn get_top_users_activity_summary(
        &self,
        time_period: Range<OffsetDateTime>,
        max_user_count: usize,
    ) -> Result<Vec<UserActivitySummary>>;

    /// Get the project activity for the given user and time period.
    async fn get_user_activity_timeline(
        &self,
        time_period: Range<OffsetDateTime>,
        user_id: UserId,
    ) -> Result<Vec<UserActivityPeriod>>;

    /// Returns the contacts of the given user.
    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    /// Reports whether the two users are contacts of each other.
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    /// Records a contact request from `requester_id` to `responder_id`.
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Removes the contact relationship between the two users.
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    // NOTE(review): presumably clears the pending notification for this pair of
    // users — confirm against the implementations.
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    // NOTE(review): presumably `accept` chooses between accepting and declining
    // the request from `requester_id` — confirm against the implementations.
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
    // NOTE(review): despite the blank line below, this attribute applies to the
    // next item, so `find_org_by_slug` only exists in test / "seed-support" builds.
    #[cfg(any(test, feature = "seed-support"))]

    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    // NOTE(review): despite the blank line below, this attribute applies to the
    // next item, so `get_org_channels` only exists in test / "seed-support" builds.
    #[cfg(any(test, feature = "seed-support"))]

    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    // Test-only hooks.
    #[cfg(test)]
    async fn teardown(&self, url: &str);
    #[cfg(test)]
    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
}
 159
/// `Db` implementation that runs its queries against a Postgres connection pool.
pub struct PostgresDb {
    /// Connection pool that queries and transactions are executed on.
    pool: sqlx::PgPool,
}
 163
 164impl PostgresDb {
 165    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 166        let pool = DbOptions::new()
 167            .max_connections(max_connections)
 168            .connect(&url)
 169            .await
 170            .context("failed to connect to postgres database")?;
 171        Ok(Self { pool })
 172    }
 173}
 174
 175#[async_trait]
 176impl Db for PostgresDb {
 177    // users
 178
 179    async fn create_user(
 180        &self,
 181        github_login: &str,
 182        email_address: Option<&str>,
 183        admin: bool,
 184    ) -> Result<UserId> {
 185        let query = "
 186            INSERT INTO users (github_login, email_address, admin)
 187            VALUES ($1, $2, $3)
 188            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 189            RETURNING id
 190        ";
 191        Ok(sqlx::query_scalar(query)
 192            .bind(github_login)
 193            .bind(email_address)
 194            .bind(admin)
 195            .fetch_one(&self.pool)
 196            .await
 197            .map(UserId)?)
 198    }
 199
 200    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 201        let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 202        Ok(sqlx::query_as(query)
 203            .bind(limit as i32)
 204            .bind((page * limit) as i32)
 205            .fetch_all(&self.pool)
 206            .await?)
 207    }
 208
 209    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
 210        let mut query = QueryBuilder::new(
 211            "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
 212        );
 213        query.push_values(
 214            users,
 215            |mut query, (github_login, email_address, invite_count)| {
 216                query
 217                    .push_bind(github_login)
 218                    .push_bind(email_address)
 219                    .push_bind(false)
 220                    .push_bind(nanoid!(16))
 221                    .push_bind(invite_count as i32);
 222            },
 223        );
 224        query.push(
 225            "
 226            ON CONFLICT (github_login) DO UPDATE SET
 227                github_login = excluded.github_login,
 228                invite_count = excluded.invite_count,
 229                invite_code = CASE WHEN users.invite_code IS NULL
 230                                   THEN excluded.invite_code
 231                                   ELSE users.invite_code
 232                              END
 233            RETURNING id
 234            ",
 235        );
 236
 237        let rows = query.build().fetch_all(&self.pool).await?;
 238        Ok(rows
 239            .into_iter()
 240            .filter_map(|row| row.try_get::<UserId, _>(0).ok())
 241            .collect())
 242    }
 243
 244    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 245        let like_string = fuzzy_like_string(name_query);
 246        let query = "
 247            SELECT users.*
 248            FROM users
 249            WHERE github_login ILIKE $1
 250            ORDER BY github_login <-> $2
 251            LIMIT $3
 252        ";
 253        Ok(sqlx::query_as(query)
 254            .bind(like_string)
 255            .bind(name_query)
 256            .bind(limit as i32)
 257            .fetch_all(&self.pool)
 258            .await?)
 259    }
 260
 261    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 262        let users = self.get_users_by_ids(vec![id]).await?;
 263        Ok(users.into_iter().next())
 264    }
 265
 266    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 267        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 268        let query = "
 269            SELECT users.*
 270            FROM users
 271            WHERE users.id = ANY ($1)
 272        ";
 273        Ok(sqlx::query_as(query)
 274            .bind(&ids)
 275            .fetch_all(&self.pool)
 276            .await?)
 277    }
 278
 279    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
 280        let query = format!(
 281            "
 282            SELECT users.*
 283            FROM users
 284            WHERE invite_count = 0
 285            AND inviter_id IS{} NULL
 286            ",
 287            if invited_by_another_user { " NOT" } else { "" }
 288        );
 289
 290        Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 291    }
 292
 293    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 294        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 295        Ok(sqlx::query_as(query)
 296            .bind(github_login)
 297            .fetch_optional(&self.pool)
 298            .await?)
 299    }
 300
 301    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 302        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 303        Ok(sqlx::query(query)
 304            .bind(is_admin)
 305            .bind(id.0)
 306            .execute(&self.pool)
 307            .await
 308            .map(drop)?)
 309    }
 310
 311    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 312        let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 313        Ok(sqlx::query(query)
 314            .bind(connected_once)
 315            .bind(id.0)
 316            .execute(&self.pool)
 317            .await
 318            .map(drop)?)
 319    }
 320
 321    async fn destroy_user(&self, id: UserId) -> Result<()> {
 322        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 323        sqlx::query(query)
 324            .bind(id.0)
 325            .execute(&self.pool)
 326            .await
 327            .map(drop)?;
 328        let query = "DELETE FROM users WHERE id = $1;";
 329        Ok(sqlx::query(query)
 330            .bind(id.0)
 331            .execute(&self.pool)
 332            .await
 333            .map(drop)?)
 334    }
 335
 336    // invite codes
 337
 338    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
 339        let mut tx = self.pool.begin().await?;
 340        if count > 0 {
 341            sqlx::query(
 342                "
 343                UPDATE users
 344                SET invite_code = $1
 345                WHERE id = $2 AND invite_code IS NULL
 346            ",
 347            )
 348            .bind(nanoid!(16))
 349            .bind(id)
 350            .execute(&mut tx)
 351            .await?;
 352        }
 353
 354        sqlx::query(
 355            "
 356            UPDATE users
 357            SET invite_count = $1
 358            WHERE id = $2
 359            ",
 360        )
 361        .bind(count as i32)
 362        .bind(id)
 363        .execute(&mut tx)
 364        .await?;
 365        tx.commit().await?;
 366        Ok(())
 367    }
 368
 369    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 370        let result: Option<(String, i32)> = sqlx::query_as(
 371            "
 372                SELECT invite_code, invite_count
 373                FROM users
 374                WHERE id = $1 AND invite_code IS NOT NULL 
 375            ",
 376        )
 377        .bind(id)
 378        .fetch_optional(&self.pool)
 379        .await?;
 380        if let Some((code, count)) = result {
 381            Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 382        } else {
 383            Ok(None)
 384        }
 385    }
 386
 387    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 388        sqlx::query_as(
 389            "
 390                SELECT *
 391                FROM users
 392                WHERE invite_code = $1
 393            ",
 394        )
 395        .bind(code)
 396        .fetch_optional(&self.pool)
 397        .await?
 398        .ok_or_else(|| {
 399            Error::Http(
 400                StatusCode::NOT_FOUND,
 401                "that invite code does not exist".to_string(),
 402            )
 403        })
 404    }
 405
    /// Redeems the invite code `code` for a new user and returns the new
    /// user's id.
    ///
    /// Within a single transaction this:
    /// 1. decrements the inviter's `invite_count` (only if it is positive),
    /// 2. inserts a non-admin user row for `login` / `email_address` that
    ///    records the inviter,
    /// 3. inserts an already-accepted contact row between inviter and invitee.
    ///
    /// # Errors
    ///
    /// * HTTP 401 "no invites remaining" when the code exists but its owner has
    ///   no invites left.
    /// * HTTP 404 "invite code not found" when no user owns the code.
    /// * Any database error from the statements or the commit.
    async fn redeem_invite_code(
        &self,
        code: &str,
        login: &str,
        email_address: Option<&str>,
    ) -> Result<UserId> {
        let mut tx = self.pool.begin().await?;

        // Claim one invite. The `invite_count > 0` guard means no row is
        // returned when the code is unknown or exhausted.
        let inviter_id: Option<UserId> = sqlx::query_scalar(
            "
                UPDATE users
                SET invite_count = invite_count - 1
                WHERE
                    invite_code = $1 AND
                    invite_count > 0
                RETURNING id
            ",
        )
        .bind(code)
        .fetch_optional(&mut tx)
        .await?;

        let inviter_id = match inviter_id {
            Some(inviter_id) => inviter_id,
            None => {
                // Distinguish "code exists but is used up" from "no such code"
                // so the caller gets the matching HTTP status.
                if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
                    .bind(code)
                    .fetch_optional(&mut tx)
                    .await?
                    .is_some()
                {
                    Err(Error::Http(
                        StatusCode::UNAUTHORIZED,
                        "no invites remaining".to_string(),
                    ))?
                } else {
                    Err(Error::Http(
                        StatusCode::NOT_FOUND,
                        "invite code not found".to_string(),
                    ))?
                }
            }
        };

        // Create the invitee. There is no conflict clause here, so an existing
        // `github_login` presumably surfaces as a database error — TODO confirm
        // against the schema's constraints.
        let invitee_id = sqlx::query_scalar(
            "
                INSERT INTO users
                    (github_login, email_address, admin, inviter_id)
                VALUES
                    ($1, $2, 'f', $3)
                RETURNING id
            ",
        )
        .bind(login)
        .bind(email_address)
        .bind(inviter_id)
        .fetch_one(&mut tx)
        .await
        .map(UserId)?;

        // Inviter and invitee become accepted contacts right away; the inviter
        // is stored as `user_id_a` with the request direction a -> b.
        sqlx::query(
            "
                INSERT INTO contacts
                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
                VALUES
                    ($1, $2, 't', 't', 't')
            ",
        )
        .bind(inviter_id)
        .bind(invitee_id)
        .execute(&mut tx)
        .await?;

        tx.commit().await?;
        Ok(invitee_id)
    }
 482
 483    // projects
 484
 485    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
 486        Ok(sqlx::query_scalar(
 487            "
 488            INSERT INTO projects(host_user_id)
 489            VALUES ($1)
 490            RETURNING id
 491            ",
 492        )
 493        .bind(host_user_id)
 494        .fetch_one(&self.pool)
 495        .await
 496        .map(ProjectId)?)
 497    }
 498
 499    async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
 500        sqlx::query(
 501            "
 502            UPDATE projects
 503            SET unregistered = 't'
 504            WHERE id = $1
 505            ",
 506        )
 507        .bind(project_id)
 508        .execute(&self.pool)
 509        .await?;
 510        Ok(())
 511    }
 512
 513    async fn update_worktree_extensions(
 514        &self,
 515        project_id: ProjectId,
 516        worktree_id: u64,
 517        extensions: HashMap<String, u32>,
 518    ) -> Result<()> {
 519        if extensions.is_empty() {
 520            return Ok(());
 521        }
 522
 523        let mut query = QueryBuilder::new(
 524            "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
 525        );
 526        query.push_values(extensions, |mut query, (extension, count)| {
 527            query
 528                .push_bind(project_id)
 529                .push_bind(worktree_id as i32)
 530                .push_bind(extension)
 531                .push_bind(count as i32);
 532        });
 533        query.push(
 534            "
 535            ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
 536            count = excluded.count
 537            ",
 538        );
 539        query.build().execute(&self.pool).await?;
 540
 541        Ok(())
 542    }
 543
 544    async fn get_project_extensions(
 545        &self,
 546        project_id: ProjectId,
 547    ) -> Result<HashMap<u64, HashMap<String, usize>>> {
 548        #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
 549        struct WorktreeExtension {
 550            worktree_id: i32,
 551            extension: String,
 552            count: i32,
 553        }
 554
 555        let query = "
 556            SELECT worktree_id, extension, count
 557            FROM worktree_extensions
 558            WHERE project_id = $1
 559        ";
 560        let counts = sqlx::query_as::<_, WorktreeExtension>(query)
 561            .bind(&project_id)
 562            .fetch_all(&self.pool)
 563            .await?;
 564
 565        let mut extension_counts = HashMap::default();
 566        for count in counts {
 567            extension_counts
 568                .entry(count.worktree_id as u64)
 569                .or_insert(HashMap::default())
 570                .insert(count.extension, count.count as usize);
 571        }
 572        Ok(extension_counts)
 573    }
 574
 575    async fn record_user_activity(
 576        &self,
 577        time_period: Range<OffsetDateTime>,
 578        projects: &[(UserId, ProjectId)],
 579    ) -> Result<()> {
 580        let query = "
 581            INSERT INTO project_activity_periods
 582            (ended_at, duration_millis, user_id, project_id)
 583            VALUES
 584            ($1, $2, $3, $4);
 585        ";
 586
 587        let mut tx = self.pool.begin().await?;
 588        let duration_millis =
 589            ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
 590        for (user_id, project_id) in projects {
 591            sqlx::query(query)
 592                .bind(time_period.end)
 593                .bind(duration_millis)
 594                .bind(user_id)
 595                .bind(project_id)
 596                .execute(&mut tx)
 597                .await?;
 598        }
 599        tx.commit().await?;
 600
 601        Ok(())
 602    }
 603
 604    async fn get_active_user_count(
 605        &self,
 606        time_period: Range<OffsetDateTime>,
 607        min_duration: Duration,
 608    ) -> Result<usize> {
 609        let query = "
 610            WITH
 611                project_durations AS (
 612                    SELECT user_id, project_id, SUM(duration_millis) AS project_duration
 613                    FROM project_activity_periods
 614                    WHERE $1 < ended_at AND ended_at <= $2
 615                    GROUP BY user_id, project_id
 616                ),
 617                user_durations AS (
 618                    SELECT user_id, SUM(project_duration) as total_duration
 619                    FROM project_durations
 620                    GROUP BY user_id
 621                    ORDER BY total_duration DESC
 622                    LIMIT $3
 623                )
 624            SELECT count(user_durations.user_id)
 625            FROM user_durations
 626            WHERE user_durations.total_duration >= $3
 627        ";
 628
 629        let count: i64 = sqlx::query_scalar(query)
 630            .bind(time_period.start)
 631            .bind(time_period.end)
 632            .bind(min_duration.as_millis() as i64)
 633            .fetch_one(&self.pool)
 634            .await?;
 635        Ok(count as usize)
 636    }
 637
 638    async fn get_top_users_activity_summary(
 639        &self,
 640        time_period: Range<OffsetDateTime>,
 641        max_user_count: usize,
 642    ) -> Result<Vec<UserActivitySummary>> {
 643        let query = "
 644            WITH
 645                project_durations AS (
 646                    SELECT user_id, project_id, SUM(duration_millis) AS project_duration
 647                    FROM project_activity_periods
 648                    WHERE $1 < ended_at AND ended_at <= $2
 649                    GROUP BY user_id, project_id
 650                ),
 651                user_durations AS (
 652                    SELECT user_id, SUM(project_duration) as total_duration
 653                    FROM project_durations
 654                    GROUP BY user_id
 655                    ORDER BY total_duration DESC
 656                    LIMIT $3
 657                )
 658            SELECT user_durations.user_id, users.github_login, project_id, project_duration
 659            FROM user_durations, project_durations, users
 660            WHERE
 661                user_durations.user_id = project_durations.user_id AND
 662                user_durations.user_id = users.id
 663            ORDER BY total_duration DESC, user_id ASC
 664        ";
 665
 666        let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64)>(query)
 667            .bind(time_period.start)
 668            .bind(time_period.end)
 669            .bind(max_user_count as i32)
 670            .fetch(&self.pool);
 671
 672        let mut result = Vec::<UserActivitySummary>::new();
 673        while let Some(row) = rows.next().await {
 674            let (user_id, github_login, project_id, duration_millis) = row?;
 675            let project_id = project_id;
 676            let duration = Duration::from_millis(duration_millis as u64);
 677            if let Some(last_summary) = result.last_mut() {
 678                if last_summary.id == user_id {
 679                    last_summary.project_activity.push((project_id, duration));
 680                    continue;
 681                }
 682            }
 683            result.push(UserActivitySummary {
 684                id: user_id,
 685                project_activity: vec![(project_id, duration)],
 686                github_login,
 687            });
 688        }
 689
 690        Ok(result)
 691    }
 692
    /// Builds the activity timeline of `user_id` from the periods that ended
    /// within `(time_period.start, time_period.end]`.
    ///
    /// Stored periods of the same project that overlap or lie within
    /// `COALESCE_THRESHOLD` of each other are merged into one
    /// `UserActivityPeriod`. Each resulting period also carries the extension
    /// counts recorded for its project. The result is sorted by start time.
    async fn get_user_activity_timeline(
        &self,
        time_period: Range<OffsetDateTime>,
        user_id: UserId,
    ) -> Result<Vec<UserActivityPeriod>> {
        // Maximum gap between two periods of a project that still merges them.
        const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);

        // The LEFT OUTER JOIN repeats every activity row once per extension row
        // of its project (or once with NULL extension columns when the project
        // has none). The loop below relies on those repeats merging into the
        // period that the first copy created.
        let query = "
            SELECT
                project_activity_periods.ended_at,
                project_activity_periods.duration_millis,
                project_activity_periods.project_id,
                worktree_extensions.extension,
                worktree_extensions.count
            FROM project_activity_periods
            LEFT OUTER JOIN
                worktree_extensions
            ON
                project_activity_periods.project_id = worktree_extensions.project_id
            WHERE
                project_activity_periods.user_id = $1 AND
                $2 < project_activity_periods.ended_at AND
                project_activity_periods.ended_at <= $3
            ORDER BY project_activity_periods.id ASC
        ";

        let mut rows = sqlx::query_as::<
            _,
            (
                PrimitiveDateTime,
                i32,
                ProjectId,
                Option<String>,
                Option<i32>,
            ),
        >(query)
        .bind(user_id)
        .bind(time_period.start)
        .bind(time_period.end)
        .fetch(&self.pool);

        // Periods are accumulated per project, in row order.
        let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
        while let Some(row) = rows.next().await {
            let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
            // The timestamp is decoded without an offset and interpreted as UTC.
            let ended_at = ended_at.assume_utc();
            let duration = Duration::from_millis(duration_millis as u64);
            let started_at = ended_at - duration;
            let project_time_periods = time_periods.entry(project_id).or_default();

            if let Some(prev_duration) = project_time_periods.last_mut() {
                // Merge into the previous period when this one starts no later
                // than the threshold after its end and does not end before its
                // start; otherwise open a new period.
                if started_at <= prev_duration.end + COALESCE_THRESHOLD
                    && ended_at >= prev_duration.start
                {
                    prev_duration.end = cmp::max(prev_duration.end, ended_at);
                } else {
                    project_time_periods.push(UserActivityPeriod {
                        project_id,
                        start: started_at,
                        end: ended_at,
                        extensions: Default::default(),
                    });
                }
            } else {
                // First period seen for this project.
                project_time_periods.push(UserActivityPeriod {
                    project_id,
                    start: started_at,
                    end: ended_at,
                    extensions: Default::default(),
                });
            }

            // Attach the extension count carried by this joined row to the
            // period just created or extended. The `unwrap` cannot fail: both
            // branches above leave at least one element in the list.
            if let Some((extension, extension_count)) = extension.zip(extension_count) {
                project_time_periods
                    .last_mut()
                    .unwrap()
                    .extensions
                    .insert(extension, extension_count as usize);
            }
        }

        // Flatten the per-project lists into one timeline ordered by start.
        let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
        durations.sort_unstable_by_key(|duration| duration.start);
        Ok(durations)
    }
 777
 778    // contacts
 779
 780    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 781        let query = "
 782            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 783            FROM contacts
 784            WHERE user_id_a = $1 OR user_id_b = $1;
 785        ";
 786
 787        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 788            .bind(user_id)
 789            .fetch(&self.pool);
 790
 791        let mut contacts = vec![Contact::Accepted {
 792            user_id,
 793            should_notify: false,
 794        }];
 795        while let Some(row) = rows.next().await {
 796            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 797
 798            if user_id_a == user_id {
 799                if accepted {
 800                    contacts.push(Contact::Accepted {
 801                        user_id: user_id_b,
 802                        should_notify: should_notify && a_to_b,
 803                    });
 804                } else if a_to_b {
 805                    contacts.push(Contact::Outgoing { user_id: user_id_b })
 806                } else {
 807                    contacts.push(Contact::Incoming {
 808                        user_id: user_id_b,
 809                        should_notify,
 810                    });
 811                }
 812            } else {
 813                if accepted {
 814                    contacts.push(Contact::Accepted {
 815                        user_id: user_id_a,
 816                        should_notify: should_notify && !a_to_b,
 817                    });
 818                } else if a_to_b {
 819                    contacts.push(Contact::Incoming {
 820                        user_id: user_id_a,
 821                        should_notify,
 822                    });
 823                } else {
 824                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 825                }
 826            }
 827        }
 828
 829        contacts.sort_unstable_by_key(|contact| contact.user_id());
 830
 831        Ok(contacts)
 832    }
 833
 834    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 835        let (id_a, id_b) = if user_id_1 < user_id_2 {
 836            (user_id_1, user_id_2)
 837        } else {
 838            (user_id_2, user_id_1)
 839        };
 840
 841        let query = "
 842            SELECT 1 FROM contacts
 843            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
 844            LIMIT 1
 845        ";
 846        Ok(sqlx::query_scalar::<_, i32>(query)
 847            .bind(id_a.0)
 848            .bind(id_b.0)
 849            .fetch_optional(&self.pool)
 850            .await?
 851            .is_some())
 852    }
 853
 854    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 855        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 856            (sender_id, receiver_id, true)
 857        } else {
 858            (receiver_id, sender_id, false)
 859        };
 860        let query = "
 861            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 862            VALUES ($1, $2, $3, 'f', 't')
 863            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 864            SET
 865                accepted = 't',
 866                should_notify = 'f'
 867            WHERE
 868                NOT contacts.accepted AND
 869                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 870                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 871        ";
 872        let result = sqlx::query(query)
 873            .bind(id_a.0)
 874            .bind(id_b.0)
 875            .bind(a_to_b)
 876            .execute(&self.pool)
 877            .await?;
 878
 879        if result.rows_affected() == 1 {
 880            Ok(())
 881        } else {
 882            Err(anyhow!("contact already requested"))?
 883        }
 884    }
 885
 886    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 887        let (id_a, id_b) = if responder_id < requester_id {
 888            (responder_id, requester_id)
 889        } else {
 890            (requester_id, responder_id)
 891        };
 892        let query = "
 893            DELETE FROM contacts
 894            WHERE user_id_a = $1 AND user_id_b = $2;
 895        ";
 896        let result = sqlx::query(query)
 897            .bind(id_a.0)
 898            .bind(id_b.0)
 899            .execute(&self.pool)
 900            .await?;
 901
 902        if result.rows_affected() == 1 {
 903            Ok(())
 904        } else {
 905            Err(anyhow!("no such contact"))?
 906        }
 907    }
 908
 909    async fn dismiss_contact_notification(
 910        &self,
 911        user_id: UserId,
 912        contact_user_id: UserId,
 913    ) -> Result<()> {
 914        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
 915            (user_id, contact_user_id, true)
 916        } else {
 917            (contact_user_id, user_id, false)
 918        };
 919
 920        let query = "
 921            UPDATE contacts
 922            SET should_notify = 'f'
 923            WHERE
 924                user_id_a = $1 AND user_id_b = $2 AND
 925                (
 926                    (a_to_b = $3 AND accepted) OR
 927                    (a_to_b != $3 AND NOT accepted)
 928                );
 929        ";
 930
 931        let result = sqlx::query(query)
 932            .bind(id_a.0)
 933            .bind(id_b.0)
 934            .bind(a_to_b)
 935            .execute(&self.pool)
 936            .await?;
 937
 938        if result.rows_affected() == 0 {
 939            Err(anyhow!("no such contact request"))?;
 940        }
 941
 942        Ok(())
 943    }
 944
 945    async fn respond_to_contact_request(
 946        &self,
 947        responder_id: UserId,
 948        requester_id: UserId,
 949        accept: bool,
 950    ) -> Result<()> {
 951        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
 952            (responder_id, requester_id, false)
 953        } else {
 954            (requester_id, responder_id, true)
 955        };
 956        let result = if accept {
 957            let query = "
 958                UPDATE contacts
 959                SET accepted = 't', should_notify = 't'
 960                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
 961            ";
 962            sqlx::query(query)
 963                .bind(id_a.0)
 964                .bind(id_b.0)
 965                .bind(a_to_b)
 966                .execute(&self.pool)
 967                .await?
 968        } else {
 969            let query = "
 970                DELETE FROM contacts
 971                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
 972            ";
 973            sqlx::query(query)
 974                .bind(id_a.0)
 975                .bind(id_b.0)
 976                .bind(a_to_b)
 977                .execute(&self.pool)
 978                .await?
 979        };
 980        if result.rows_affected() == 1 {
 981            Ok(())
 982        } else {
 983            Err(anyhow!("no such contact request"))?
 984        }
 985    }
 986
 987    // access tokens
 988
 989    async fn create_access_token_hash(
 990        &self,
 991        user_id: UserId,
 992        access_token_hash: &str,
 993        max_access_token_count: usize,
 994    ) -> Result<()> {
 995        let insert_query = "
 996            INSERT INTO access_tokens (user_id, hash)
 997            VALUES ($1, $2);
 998        ";
 999        let cleanup_query = "
1000            DELETE FROM access_tokens
1001            WHERE id IN (
1002                SELECT id from access_tokens
1003                WHERE user_id = $1
1004                ORDER BY id DESC
1005                OFFSET $3
1006            )
1007        ";
1008
1009        let mut tx = self.pool.begin().await?;
1010        sqlx::query(insert_query)
1011            .bind(user_id.0)
1012            .bind(access_token_hash)
1013            .execute(&mut tx)
1014            .await?;
1015        sqlx::query(cleanup_query)
1016            .bind(user_id.0)
1017            .bind(access_token_hash)
1018            .bind(max_access_token_count as i32)
1019            .execute(&mut tx)
1020            .await?;
1021        Ok(tx.commit().await?)
1022    }
1023
1024    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1025        let query = "
1026            SELECT hash
1027            FROM access_tokens
1028            WHERE user_id = $1
1029            ORDER BY id DESC
1030        ";
1031        Ok(sqlx::query_scalar(query)
1032            .bind(user_id.0)
1033            .fetch_all(&self.pool)
1034            .await?)
1035    }
1036
1037    // orgs
1038
1039    #[allow(unused)] // Help rust-analyzer
1040    #[cfg(any(test, feature = "seed-support"))]
1041    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1042        let query = "
1043            SELECT *
1044            FROM orgs
1045            WHERE slug = $1
1046        ";
1047        Ok(sqlx::query_as(query)
1048            .bind(slug)
1049            .fetch_optional(&self.pool)
1050            .await?)
1051    }
1052
1053    #[cfg(any(test, feature = "seed-support"))]
1054    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1055        let query = "
1056            INSERT INTO orgs (name, slug)
1057            VALUES ($1, $2)
1058            RETURNING id
1059        ";
1060        Ok(sqlx::query_scalar(query)
1061            .bind(name)
1062            .bind(slug)
1063            .fetch_one(&self.pool)
1064            .await
1065            .map(OrgId)?)
1066    }
1067
1068    #[cfg(any(test, feature = "seed-support"))]
1069    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1070        let query = "
1071            INSERT INTO org_memberships (org_id, user_id, admin)
1072            VALUES ($1, $2, $3)
1073            ON CONFLICT DO NOTHING
1074        ";
1075        Ok(sqlx::query(query)
1076            .bind(org_id.0)
1077            .bind(user_id.0)
1078            .bind(is_admin)
1079            .execute(&self.pool)
1080            .await
1081            .map(drop)?)
1082    }
1083
1084    // channels
1085
1086    #[cfg(any(test, feature = "seed-support"))]
1087    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1088        let query = "
1089            INSERT INTO channels (owner_id, owner_is_user, name)
1090            VALUES ($1, false, $2)
1091            RETURNING id
1092        ";
1093        Ok(sqlx::query_scalar(query)
1094            .bind(org_id.0)
1095            .bind(name)
1096            .fetch_one(&self.pool)
1097            .await
1098            .map(ChannelId)?)
1099    }
1100
1101    #[allow(unused)] // Help rust-analyzer
1102    #[cfg(any(test, feature = "seed-support"))]
1103    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1104        let query = "
1105            SELECT *
1106            FROM channels
1107            WHERE
1108                channels.owner_is_user = false AND
1109                channels.owner_id = $1
1110        ";
1111        Ok(sqlx::query_as(query)
1112            .bind(org_id.0)
1113            .fetch_all(&self.pool)
1114            .await?)
1115    }
1116
1117    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1118        let query = "
1119            SELECT
1120                channels.*
1121            FROM
1122                channel_memberships, channels
1123            WHERE
1124                channel_memberships.user_id = $1 AND
1125                channel_memberships.channel_id = channels.id
1126        ";
1127        Ok(sqlx::query_as(query)
1128            .bind(user_id.0)
1129            .fetch_all(&self.pool)
1130            .await?)
1131    }
1132
1133    async fn can_user_access_channel(
1134        &self,
1135        user_id: UserId,
1136        channel_id: ChannelId,
1137    ) -> Result<bool> {
1138        let query = "
1139            SELECT id
1140            FROM channel_memberships
1141            WHERE user_id = $1 AND channel_id = $2
1142            LIMIT 1
1143        ";
1144        Ok(sqlx::query_scalar::<_, i32>(query)
1145            .bind(user_id.0)
1146            .bind(channel_id.0)
1147            .fetch_optional(&self.pool)
1148            .await
1149            .map(|e| e.is_some())?)
1150    }
1151
1152    #[cfg(any(test, feature = "seed-support"))]
1153    async fn add_channel_member(
1154        &self,
1155        channel_id: ChannelId,
1156        user_id: UserId,
1157        is_admin: bool,
1158    ) -> Result<()> {
1159        let query = "
1160            INSERT INTO channel_memberships (channel_id, user_id, admin)
1161            VALUES ($1, $2, $3)
1162            ON CONFLICT DO NOTHING
1163        ";
1164        Ok(sqlx::query(query)
1165            .bind(channel_id.0)
1166            .bind(user_id.0)
1167            .bind(is_admin)
1168            .execute(&self.pool)
1169            .await
1170            .map(drop)?)
1171    }
1172
1173    // messages
1174
1175    async fn create_channel_message(
1176        &self,
1177        channel_id: ChannelId,
1178        sender_id: UserId,
1179        body: &str,
1180        timestamp: OffsetDateTime,
1181        nonce: u128,
1182    ) -> Result<MessageId> {
1183        let query = "
1184            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1185            VALUES ($1, $2, $3, $4, $5)
1186            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1187            RETURNING id
1188        ";
1189        Ok(sqlx::query_scalar(query)
1190            .bind(channel_id.0)
1191            .bind(sender_id.0)
1192            .bind(body)
1193            .bind(timestamp)
1194            .bind(Uuid::from_u128(nonce))
1195            .fetch_one(&self.pool)
1196            .await
1197            .map(MessageId)?)
1198    }
1199
1200    async fn get_channel_messages(
1201        &self,
1202        channel_id: ChannelId,
1203        count: usize,
1204        before_id: Option<MessageId>,
1205    ) -> Result<Vec<ChannelMessage>> {
1206        let query = r#"
1207            SELECT * FROM (
1208                SELECT
1209                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1210                FROM
1211                    channel_messages
1212                WHERE
1213                    channel_id = $1 AND
1214                    id < $2
1215                ORDER BY id DESC
1216                LIMIT $3
1217            ) as recent_messages
1218            ORDER BY id ASC
1219        "#;
1220        Ok(sqlx::query_as(query)
1221            .bind(channel_id.0)
1222            .bind(before_id.unwrap_or(MessageId::MAX))
1223            .bind(count as i64)
1224            .fetch_all(&self.pool)
1225            .await?)
1226    }
1227
1228    #[cfg(test)]
1229    async fn teardown(&self, url: &str) {
1230        use util::ResultExt;
1231
1232        let query = "
1233            SELECT pg_terminate_backend(pg_stat_activity.pid)
1234            FROM pg_stat_activity
1235            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1236        ";
1237        sqlx::query(query).execute(&self.pool).await.log_err();
1238        self.pool.close().await;
1239        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1240            .await
1241            .log_err();
1242    }
1243
    /// Test-only downcast hook. This implementation always returns `None`,
    /// i.e. it never exposes a `tests::FakeDb`.
    #[cfg(test)]
    fn as_fake(&self) -> Option<&tests::FakeDb> {
        None
    }
1248}
1249
// Declares a public newtype `$name` wrapping an `i32` id.
//
// The generated type is `Copy`, ordered, hashable, transparent for both sqlx
// and serde (it encodes exactly like the inner `i32`), and implements
// `Display` by delegating to the inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // Largest representable id (`i32::MAX`).
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // Builds an id from a `u64`. NOTE: the `as i32` cast keeps only
            // the low 32 bits, so values above `i32::MAX` wrap silently.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // Converts the id to a `u64`. NOTE: the `as u64` cast
            // sign-extends, so a negative inner value becomes a very large
            // number.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }

        // Formats the id exactly like its inner integer.
        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
1292
// Generates the `UserId` newtype.
id_type!(UserId);
/// A user record. Derives `FromRow`, so it can be decoded from a SQL row
/// whose column names match the field names.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub email_address: Option<String>,
    /// Admin flag for this user.
    pub admin: bool,
    /// This user's invite code, if one has been assigned.
    pub invite_code: Option<String>,
    // NOTE(review): presumably the number of invites the user may still send; confirm.
    pub invite_count: i32,
    pub connected_once: bool,
}
1304
// Generates the `ProjectId` newtype.
id_type!(ProjectId);
/// A project record. Derives `FromRow`, so it can be decoded from a SQL row
/// whose column names match the field names.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// Id of the user hosting the project.
    pub host_user_id: UserId,
    pub unregistered: bool,
}
1312
/// Per-user activity totals: one `(project, duration)` entry for each project
/// the user was active in.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// Pairs of project id and the duration attributed to that project.
    pub project_activity: Vec<(ProjectId, Duration)>,
}
1319
/// One period of a user's activity in a single project. The `start` and
/// `end` timestamps are serialized as ISO 8601 strings.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    project_id: ProjectId,
    #[serde(with = "time::serde::iso8601")]
    start: OffsetDateTime,
    #[serde(with = "time::serde::iso8601")]
    end: OffsetDateTime,
    // NOTE(review): presumably file extension -> count for the project's worktrees; confirm.
    extensions: HashMap<String, usize>,
}
1329
// Generates the `OrgId` newtype.
id_type!(OrgId);
/// An org record. Derives `FromRow`, so it can be decoded from a SQL row
/// whose column names match the field names.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    pub slug: String,
}
1337
// Generates the `ChannelId` newtype.
id_type!(ChannelId);
/// A channel record. Derives `FromRow`, so it can be decoded from a SQL row
/// whose column names match the field names.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Raw id of the owner; `owner_is_user` says how to interpret it.
    pub owner_id: i32,
    /// False for channels owned by an org.
    pub owner_is_user: bool,
}
1346
// Generates the `MessageId` newtype.
id_type!(MessageId);
/// A channel message record. Derives `FromRow`, so it can be decoded from a
/// SQL row whose column names match the field names.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    pub sent_at: OffsetDateTime,
    /// Nonce supplied when the message was created; inserts conflict on it.
    pub nonce: Uuid,
}
1357
/// The state of a contact relationship with another user. Each variant
/// carries the other user's id.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The contact relationship has been accepted.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    // NOTE(review): presumably a pending request sent to `user_id`; confirm against `get_contacts`.
    Outgoing {
        user_id: UserId,
    },
    // NOTE(review): presumably a pending request received from `user_id`; confirm against `get_contacts`.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1372
1373impl Contact {
1374    pub fn user_id(&self) -> UserId {
1375        match self {
1376            Contact::Accepted { user_id, .. } => *user_id,
1377            Contact::Outgoing { user_id } => *user_id,
1378            Contact::Incoming { user_id, .. } => *user_id,
1379        }
1380    }
1381}
1382
/// A contact request received from another user.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    /// Id of the user who sent the request.
    pub requester_id: UserId,
    pub should_notify: bool,
}
1388
/// Builds a SQL `LIKE` pattern for fuzzy matching: every alphanumeric
/// character of `string` is kept and preceded by `%`, all other characters
/// are dropped, and a final `%` is appended (so `"a-b"` becomes `"%a%b%"`).
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    string
        .chars()
        .filter(|ch| ch.is_alphanumeric())
        .for_each(|ch| pattern.extend(['%', ch]));
    pattern.push('%');
    pattern
}
1400
1401#[cfg(test)]
1402pub mod tests {
1403    use super::*;
1404    use anyhow::anyhow;
1405    use collections::BTreeMap;
1406    use gpui::executor::{Background, Deterministic};
1407    use lazy_static::lazy_static;
1408    use parking_lot::Mutex;
1409    use rand::prelude::*;
1410    use sqlx::{
1411        migrate::{MigrateDatabase, Migrator},
1412        Postgres,
1413    };
1414    use std::{path::Path, sync::Arc};
1415    use util::post_inc;
1416
1417    #[tokio::test(flavor = "multi_thread")]
1418    async fn test_get_users_by_ids() {
1419        for test_db in [
1420            TestDb::postgres().await,
1421            TestDb::fake(build_background_executor()),
1422        ] {
1423            let db = test_db.db();
1424
1425            let user = db.create_user("user", None, false).await.unwrap();
1426            let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1427            let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1428            let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1429
1430            assert_eq!(
1431                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1432                    .await
1433                    .unwrap(),
1434                vec![
1435                    User {
1436                        id: user,
1437                        github_login: "user".to_string(),
1438                        admin: false,
1439                        ..Default::default()
1440                    },
1441                    User {
1442                        id: friend1,
1443                        github_login: "friend-1".to_string(),
1444                        admin: false,
1445                        ..Default::default()
1446                    },
1447                    User {
1448                        id: friend2,
1449                        github_login: "friend-2".to_string(),
1450                        admin: false,
1451                        ..Default::default()
1452                    },
1453                    User {
1454                        id: friend3,
1455                        github_login: "friend-3".to_string(),
1456                        admin: false,
1457                        ..Default::default()
1458                    }
1459                ]
1460            );
1461        }
1462    }
1463
    /// Exercises `create_users` against Postgres: a first batch of new users,
    /// then a second batch that repeats one existing login. The repeated user
    /// must keep their id and invite code while their invite count changes.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_create_users() {
        let db = TestDb::postgres().await;
        let db = db.db();

        // Create the first batch of users, ensuring invite counts are assigned
        // correctly and the respective invite codes are unique.
        let user_ids_batch_1 = db
            .create_users(vec![
                ("user1".to_string(), "hi@user1.com".to_string(), 5),
                ("user2".to_string(), "hi@user2.com".to_string(), 4),
                ("user3".to_string(), "hi@user3.com".to_string(), 3),
            ])
            .await
            .unwrap();
        assert_eq!(user_ids_batch_1.len(), 3);

        let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
        assert_eq!(users.len(), 3);
        assert_eq!(users[0].github_login, "user1");
        assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
        assert_eq!(users[0].invite_count, 5);
        assert_eq!(users[1].github_login, "user2");
        assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
        assert_eq!(users[1].invite_count, 4);
        assert_eq!(users[2].github_login, "user3");
        assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
        assert_eq!(users[2].invite_count, 3);

        let invite_code_1 = users[0].invite_code.clone().unwrap();
        let invite_code_2 = users[1].invite_code.clone().unwrap();
        let invite_code_3 = users[2].invite_code.clone().unwrap();
        assert_ne!(invite_code_1, invite_code_2);
        assert_ne!(invite_code_1, invite_code_3);
        assert_ne!(invite_code_2, invite_code_3);

        // Create the second batch of users and include a user that is already in the database, ensuring
        // the invite count for the existing user is updated without changing their invite code.
        let user_ids_batch_2 = db
            .create_users(vec![
                ("user2".to_string(), "hi@user2.com".to_string(), 10),
                ("user4".to_string(), "hi@user4.com".to_string(), 2),
            ])
            .await
            .unwrap();
        assert_eq!(user_ids_batch_2.len(), 2);
        // The existing user keeps the id assigned in the first batch.
        assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);

        let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
        assert_eq!(users.len(), 2);
        assert_eq!(users[0].github_login, "user2");
        assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
        assert_eq!(users[0].invite_count, 10);
        assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
        assert_eq!(users[1].github_login, "user4");
        assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
        assert_eq!(users[1].invite_count, 2);

        // The new user's invite code must differ from every earlier one.
        let invite_code_4 = users[1].invite_code.clone().unwrap();
        assert_ne!(invite_code_4, invite_code_1);
        assert_ne!(invite_code_4, invite_code_2);
        assert_ne!(invite_code_4, invite_code_3);
    }
1527
    /// Checks that `update_worktree_extensions` keeps one extension map per
    /// worktree id, and that a later update for the same worktree replaces the
    /// earlier counts.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_worktree_extensions() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();

        let user = db.create_user("user_1", None, false).await.unwrap();
        let project = db.register_project(user).await.unwrap();

        // Worktree 100: an empty update, then two updates with growing counts.
        db.update_worktree_extensions(project, 100, Default::default())
            .await
            .unwrap();
        db.update_worktree_extensions(
            project,
            100,
            [("rs".to_string(), 5), ("md".to_string(), 3)]
                .into_iter()
                .collect(),
        )
        .await
        .unwrap();
        db.update_worktree_extensions(
            project,
            100,
            [("rs".to_string(), 6), ("md".to_string(), 5)]
                .into_iter()
                .collect(),
        )
        .await
        .unwrap();
        // Worktree 101: a single update.
        db.update_worktree_extensions(
            project,
            101,
            [("ts".to_string(), 2), ("md".to_string(), 1)]
                .into_iter()
                .collect(),
        )
        .await
        .unwrap();

        // Only the latest counts for each worktree are expected.
        assert_eq!(
            db.get_project_extensions(project).await.unwrap(),
            [
                (
                    100,
                    [("rs".into(), 6), ("md".into(), 5),]
                        .into_iter()
                        .collect::<HashMap<_, _>>()
                ),
                (
                    101,
                    [("ts".into(), 2), ("md".into(), 1),]
                        .into_iter()
                        .collect::<HashMap<_, _>>()
                )
            ]
            .into_iter()
            .collect()
        );
    }
1587
    /// Records a scripted timeline of user activity (t0 through t8) and checks
    /// the three read paths against it: the per-user activity summary, the
    /// active-user count at several duration thresholds, and the per-user
    /// activity timeline.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_user_activity() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();

        let user_1 = db.create_user("user_1", None, false).await.unwrap();
        let user_2 = db.create_user("user_2", None, false).await.unwrap();
        let user_3 = db.create_user("user_3", None, false).await.unwrap();
        let project_1 = db.register_project(user_1).await.unwrap();
        db.update_worktree_extensions(
            project_1,
            1,
            HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]),
        )
        .await
        .unwrap();
        let project_2 = db.register_project(user_2).await.unwrap();
        // Start the timeline one hour in the past.
        let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);

        // User 2 opens a project
        let t1 = t0 + Duration::from_secs(10);
        db.record_user_activity(t0..t1, &[(user_2, project_2)])
            .await
            .unwrap();

        let t2 = t1 + Duration::from_secs(10);
        db.record_user_activity(t1..t2, &[(user_2, project_2)])
            .await
            .unwrap();

        // User 1 joins the project
        let t3 = t2 + Duration::from_secs(10);
        db.record_user_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
            .await
            .unwrap();

        // User 1 opens another project
        let t4 = t3 + Duration::from_secs(10);
        db.record_user_activity(
            t3..t4,
            &[
                (user_2, project_2),
                (user_1, project_2),
                (user_1, project_1),
            ],
        )
        .await
        .unwrap();

        // User 3 joins that project
        let t5 = t4 + Duration::from_secs(10);
        db.record_user_activity(
            t4..t5,
            &[
                (user_2, project_2),
                (user_1, project_2),
                (user_1, project_1),
                (user_3, project_1),
            ],
        )
        .await
        .unwrap();

        // User 2 leaves
        let t6 = t5 + Duration::from_secs(5);
        db.record_user_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
            .await
            .unwrap();

        // After a 60 second gap with no recorded activity, user 1 is active again.
        let t7 = t6 + Duration::from_secs(60);
        let t8 = t7 + Duration::from_secs(10);
        db.record_user_activity(t7..t8, &[(user_1, project_1)])
            .await
            .unwrap();

        // Totals over t0..t6: user 1 has 25s + 30s, user 2 has 50s, user 3 has 15s.
        assert_eq!(
            db.get_top_users_activity_summary(t0..t6, 10).await.unwrap(),
            &[
                UserActivitySummary {
                    id: user_1,
                    github_login: "user_1".to_string(),
                    project_activity: vec![
                        (project_1, Duration::from_secs(25)),
                        (project_2, Duration::from_secs(30)),
                    ]
                },
                UserActivitySummary {
                    id: user_2,
                    github_login: "user_2".to_string(),
                    project_activity: vec![(project_2, Duration::from_secs(50))]
                },
                UserActivitySummary {
                    id: user_3,
                    github_login: "user_3".to_string(),
                    project_activity: vec![(project_1, Duration::from_secs(15))]
                },
            ]
        );

        // The thresholds bracket the per-user totals of 55s, 50s and 15s.
        assert_eq!(
            db.get_active_user_count(t0..t6, Duration::from_secs(56))
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            db.get_active_user_count(t0..t6, Duration::from_secs(54))
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            db.get_active_user_count(t0..t6, Duration::from_secs(30))
                .await
                .unwrap(),
            2
        );
        assert_eq!(
            db.get_active_user_count(t0..t6, Duration::from_secs(10))
                .await
                .unwrap(),
            3
        );

        // Timeline restricted to t3..t6.
        assert_eq!(
            db.get_user_activity_timeline(t3..t6, user_1).await.unwrap(),
            &[
                UserActivityPeriod {
                    project_id: project_1,
                    start: t3,
                    end: t6,
                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
                },
                UserActivityPeriod {
                    project_id: project_2,
                    start: t3,
                    end: t5,
                    extensions: Default::default(),
                },
            ]
        );
        // Full timeline: the gap between t6 and t7 splits project 1 into two periods.
        assert_eq!(
            db.get_user_activity_timeline(t0..t8, user_1).await.unwrap(),
            &[
                UserActivityPeriod {
                    project_id: project_2,
                    start: t2,
                    end: t5,
                    extensions: Default::default(),
                },
                UserActivityPeriod {
                    project_id: project_1,
                    start: t3,
                    end: t6,
                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
                },
                UserActivityPeriod {
                    project_id: project_1,
                    start: t7,
                    end: t8,
                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
                },
            ]
        );
    }
1753
1754    #[tokio::test(flavor = "multi_thread")]
1755    async fn test_recent_channel_messages() {
1756        for test_db in [
1757            TestDb::postgres().await,
1758            TestDb::fake(build_background_executor()),
1759        ] {
1760            let db = test_db.db();
1761            let user = db.create_user("user", None, false).await.unwrap();
1762            let org = db.create_org("org", "org").await.unwrap();
1763            let channel = db.create_org_channel(org, "channel").await.unwrap();
1764            for i in 0..10 {
1765                db.create_channel_message(
1766                    channel,
1767                    user,
1768                    &i.to_string(),
1769                    OffsetDateTime::now_utc(),
1770                    i,
1771                )
1772                .await
1773                .unwrap();
1774            }
1775
1776            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1777            assert_eq!(
1778                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1779                ["5", "6", "7", "8", "9"]
1780            );
1781
1782            let prev_messages = db
1783                .get_channel_messages(channel, 4, Some(messages[0].id))
1784                .await
1785                .unwrap();
1786            assert_eq!(
1787                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1788                ["1", "2", "3", "4"]
1789            );
1790        }
1791    }
1792
1793    #[tokio::test(flavor = "multi_thread")]
1794    async fn test_channel_message_nonces() {
1795        for test_db in [
1796            TestDb::postgres().await,
1797            TestDb::fake(build_background_executor()),
1798        ] {
1799            let db = test_db.db();
1800            let user = db.create_user("user", None, false).await.unwrap();
1801            let org = db.create_org("org", "org").await.unwrap();
1802            let channel = db.create_org_channel(org, "channel").await.unwrap();
1803
1804            let msg1_id = db
1805                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1806                .await
1807                .unwrap();
1808            let msg2_id = db
1809                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1810                .await
1811                .unwrap();
1812            let msg3_id = db
1813                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1814                .await
1815                .unwrap();
1816            let msg4_id = db
1817                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1818                .await
1819                .unwrap();
1820
1821            assert_ne!(msg1_id, msg2_id);
1822            assert_eq!(msg1_id, msg3_id);
1823            assert_eq!(msg2_id, msg4_id);
1824        }
1825    }
1826
1827    #[tokio::test(flavor = "multi_thread")]
1828    async fn test_create_access_tokens() {
1829        let test_db = TestDb::postgres().await;
1830        let db = test_db.db();
1831        let user = db.create_user("the-user", None, false).await.unwrap();
1832
1833        db.create_access_token_hash(user, "h1", 3).await.unwrap();
1834        db.create_access_token_hash(user, "h2", 3).await.unwrap();
1835        assert_eq!(
1836            db.get_access_token_hashes(user).await.unwrap(),
1837            &["h2".to_string(), "h1".to_string()]
1838        );
1839
1840        db.create_access_token_hash(user, "h3", 3).await.unwrap();
1841        assert_eq!(
1842            db.get_access_token_hashes(user).await.unwrap(),
1843            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1844        );
1845
1846        db.create_access_token_hash(user, "h4", 3).await.unwrap();
1847        assert_eq!(
1848            db.get_access_token_hashes(user).await.unwrap(),
1849            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1850        );
1851
1852        db.create_access_token_hash(user, "h5", 3).await.unwrap();
1853        assert_eq!(
1854            db.get_access_token_hashes(user).await.unwrap(),
1855            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1856        );
1857    }
1858
1859    #[test]
1860    fn test_fuzzy_like_string() {
1861        assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1862        assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1863        assert_eq!(fuzzy_like_string(" z  "), "%z%");
1864    }
1865
1866    #[tokio::test(flavor = "multi_thread")]
1867    async fn test_fuzzy_search_users() {
1868        let test_db = TestDb::postgres().await;
1869        let db = test_db.db();
1870        for github_login in [
1871            "California",
1872            "colorado",
1873            "oregon",
1874            "washington",
1875            "florida",
1876            "delaware",
1877            "rhode-island",
1878        ] {
1879            db.create_user(github_login, None, false).await.unwrap();
1880        }
1881
1882        assert_eq!(
1883            fuzzy_search_user_names(db, "clr").await,
1884            &["colorado", "California"]
1885        );
1886        assert_eq!(
1887            fuzzy_search_user_names(db, "ro").await,
1888            &["rhode-island", "colorado", "oregon"],
1889        );
1890
1891        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1892            db.fuzzy_search_users(query, 10)
1893                .await
1894                .unwrap()
1895                .into_iter()
1896                .map(|user| user.github_login)
1897                .collect::<Vec<_>>()
1898        }
1899    }
1900
    /// Walks through the contact lifecycle: sending, dismissing, accepting and
    /// declining contact requests, plus two users requesting each other.
    /// After each step the contact lists of the affected users are compared in
    /// full, so the order of the statements below matters.
    /// Runs against both the Postgres-backed and the fake database.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(build_background_executor()),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", None, false).await.unwrap();
            let user_2 = db.create_user("user2", None, false).await.unwrap();
            let user_3 = db.create_user("user3", None, false).await.unwrap();

            // User starts with no contacts other than themself: the list always
            // contains an accepted entry for the user's own id.
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                vec![Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );

            // User requests a contact. Both users see the pending request.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert!(!db.has_contact(user_1, user_2).await.unwrap());
            assert!(!db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Outgoing { user_id: user_2 }
                ],
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: true
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again.
            // The first call is made on behalf of user 1 (the requester) and must fail.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact.
            // Only the original requester (user 1) is flagged for notification.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true
                    }
                ],
            );
            assert!(db.has_contact(user_1, user_2).await.unwrap());
            assert!(db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users cannot re-request existing contacts.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users can't dismiss notifications of them accepting other users' requests.
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap_err();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true,
                    },
                ]
            );

            // Users can dismiss notifications of other users accepting their requests.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users send each other concurrent contact requests and
            // see that they are immediately accepted.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    },
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert!(!db.has_contact(user_2, user_3).await.unwrap());
            assert!(!db.has_contact(user_3, user_2).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    }
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );
        }
    }
2120
    /// Walks through the invite-code lifecycle: assigning an invite count,
    /// redeeming the code until it is exhausted, topping the count up again, and
    /// rejecting a redemption by a login that already exists.
    /// After each redemption the remaining count and the resulting contact lists
    /// are checked. Runs against Postgres only.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", None, false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

        // Setting invite count to 0 when no code is assigned does not assign a new code
        db.set_invite_count(user1, 0).await.unwrap();
        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        let user2 = db
            .redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db
            .redeem_invite_code(&invite_code, "user-3", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db
            .redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes.
        // The failed attempt must not consume one of the remaining invites.
        db.redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap_err();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
    }
2273
    /// A database handle for tests, built by `TestDb::postgres` or `TestDb::fake`.
    /// Dropping it calls `teardown` on the wrapped database.
    pub struct TestDb {
        /// The wrapped database. It is an `Option` so that `Drop` can take it out;
        /// both constructors set it to `Some`.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the Postgres database; empty for the fake database.
        pub url: String,
    }
2278
2279    impl TestDb {
2280        pub async fn postgres() -> Self {
2281            lazy_static! {
2282                static ref LOCK: Mutex<()> = Mutex::new(());
2283            }
2284
2285            let _guard = LOCK.lock();
2286            let mut rng = StdRng::from_entropy();
2287            let name = format!("zed-test-{}", rng.gen::<u128>());
2288            let url = format!("postgres://postgres@localhost/{}", name);
2289            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2290            Postgres::create_database(&url)
2291                .await
2292                .expect("failed to create test db");
2293            let db = PostgresDb::new(&url, 5).await.unwrap();
2294            let migrator = Migrator::new(migrations_path).await.unwrap();
2295            migrator.run(&db.pool).await.unwrap();
2296            Self {
2297                db: Some(Arc::new(db)),
2298                url,
2299            }
2300        }
2301
2302        pub fn fake(background: Arc<Background>) -> Self {
2303            Self {
2304                db: Some(Arc::new(FakeDb::new(background))),
2305                url: Default::default(),
2306            }
2307        }
2308
2309        pub fn db(&self) -> &Arc<dyn Db> {
2310            self.db.as_ref().unwrap()
2311        }
2312    }
2313
2314    impl Drop for TestDb {
2315        fn drop(&mut self) {
2316            if let Some(db) = self.db.take() {
2317                futures::executor::block_on(db.teardown(&self.url));
2318            }
2319        }
2320    }
2321
    /// In-memory implementation of `Db` for tests. All state lives in
    /// mutex-guarded collections; ids are handed out from the `next_*` counters.
    pub struct FakeDb {
        /// Executor handle used to insert a simulated random delay at the start
        /// of the implemented operations.
        background: Arc<Background>,
        /// Users keyed by id.
        pub users: Mutex<BTreeMap<UserId, User>>,
        /// Projects keyed by id.
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        /// Count per `(project, worktree id, extension)`.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
        /// Orgs keyed by id.
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        /// One flag per `(org, user)` membership.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        /// Channels keyed by id.
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        /// One flag per `(channel, user)` membership.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        /// Channel messages keyed by id.
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        /// Contact relationships, in the order they were requested.
        pub contacts: Mutex<Vec<FakeContact>>,
        // Counters backing id allocation for each entity kind.
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
2339
    /// One contact relationship stored by `FakeDb`, pending or accepted.
    #[derive(Debug)]
    pub struct FakeContact {
        /// The user who sent the contact request.
        pub requester_id: UserId,
        /// The user who received the contact request.
        pub responder_id: UserId,
        /// `false` while the request is pending, `true` once it has been accepted.
        pub accepted: bool,
        /// Notification flag. `get_contacts` reports it to the responder while the
        /// request is pending and to the requester once it is accepted.
        pub should_notify: bool,
    }
2347
2348    impl FakeDb {
2349        pub fn new(background: Arc<Background>) -> Self {
2350            Self {
2351                background,
2352                users: Default::default(),
2353                next_user_id: Mutex::new(0),
2354                projects: Default::default(),
2355                worktree_extensions: Default::default(),
2356                next_project_id: Mutex::new(1),
2357                orgs: Default::default(),
2358                next_org_id: Mutex::new(1),
2359                org_memberships: Default::default(),
2360                channels: Default::default(),
2361                next_channel_id: Mutex::new(1),
2362                channel_memberships: Default::default(),
2363                channel_messages: Default::default(),
2364                next_channel_message_id: Mutex::new(1),
2365                contacts: Default::default(),
2366            }
2367        }
2368    }
2369
2370    #[async_trait]
2371    impl Db for FakeDb {
2372        async fn create_user(
2373            &self,
2374            github_login: &str,
2375            email_address: Option<&str>,
2376            admin: bool,
2377        ) -> Result<UserId> {
2378            self.background.simulate_random_delay().await;
2379
2380            let mut users = self.users.lock();
2381            if let Some(user) = users
2382                .values()
2383                .find(|user| user.github_login == github_login)
2384            {
2385                Ok(user.id)
2386            } else {
2387                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2388                users.insert(
2389                    user_id,
2390                    User {
2391                        id: user_id,
2392                        github_login: github_login.to_string(),
2393                        email_address: email_address.map(str::to_string),
2394                        admin,
2395                        invite_code: None,
2396                        invite_count: 0,
2397                        connected_once: false,
2398                    },
2399                );
2400                Ok(user_id)
2401            }
2402        }
2403
        /// Not supported by the fake database; calling it panics via `unimplemented!()`.
        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2407
        /// Not supported by the fake database; calling it panics via `unimplemented!()`.
        async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
            unimplemented!()
        }
2411
        /// Not supported by the fake database; calling it panics via `unimplemented!()`.
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2415
2416        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2417            self.background.simulate_random_delay().await;
2418            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2419        }
2420
2421        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2422            self.background.simulate_random_delay().await;
2423            let users = self.users.lock();
2424            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2425        }
2426
        /// Not supported by the fake database; calling it panics via `unimplemented!()`.
        async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
            unimplemented!()
        }
2430
2431        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2432            self.background.simulate_random_delay().await;
2433            Ok(self
2434                .users
2435                .lock()
2436                .values()
2437                .find(|user| user.github_login == github_login)
2438                .cloned())
2439        }
2440
        /// Not supported by the fake database; calling it panics via `unimplemented!()`.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
2444
2445        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2446            self.background.simulate_random_delay().await;
2447            let mut users = self.users.lock();
2448            let mut user = users
2449                .get_mut(&id)
2450                .ok_or_else(|| anyhow!("user not found"))?;
2451            user.connected_once = connected_once;
2452            Ok(())
2453        }
2454
        /// Not supported by the fake database; calling it panics via `unimplemented!()`.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }
2458
2459        // invite codes
2460
        /// Not supported by the fake database; calling it panics via `unimplemented!()`.
        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }
2464
        /// Always reports that the user has no invite code: after the simulated
        /// delay it returns `Ok(None)` regardless of the id.
        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            self.background.simulate_random_delay().await;
            Ok(None)
        }
2469
        /// Not supported by the fake database; calling it panics via `unimplemented!()`.
        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }
2473
        /// Not supported by the fake database; calling it panics via `unimplemented!()`.
        async fn redeem_invite_code(
            &self,
            _code: &str,
            _login: &str,
            _email_address: Option<&str>,
        ) -> Result<UserId> {
            unimplemented!()
        }
2482
2483        // projects
2484
2485        async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2486            self.background.simulate_random_delay().await;
2487            if !self.users.lock().contains_key(&host_user_id) {
2488                Err(anyhow!("no such user"))?;
2489            }
2490
2491            let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2492            self.projects.lock().insert(
2493                project_id,
2494                Project {
2495                    id: project_id,
2496                    host_user_id,
2497                    unregistered: false,
2498                },
2499            );
2500            Ok(project_id)
2501        }
2502
2503        async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2504            self.background.simulate_random_delay().await;
2505            self.projects
2506                .lock()
2507                .get_mut(&project_id)
2508                .ok_or_else(|| anyhow!("no such project"))?
2509                .unregistered = true;
2510            Ok(())
2511        }
2512
2513        async fn update_worktree_extensions(
2514            &self,
2515            project_id: ProjectId,
2516            worktree_id: u64,
2517            extensions: HashMap<String, u32>,
2518        ) -> Result<()> {
2519            self.background.simulate_random_delay().await;
2520            if !self.projects.lock().contains_key(&project_id) {
2521                Err(anyhow!("no such project"))?;
2522            }
2523
2524            for (extension, count) in extensions {
2525                self.worktree_extensions
2526                    .lock()
2527                    .insert((project_id, worktree_id, extension), count);
2528            }
2529
2530            Ok(())
2531        }
2532
        /// Not supported by the fake database; calling it panics via `unimplemented!()`.
        async fn get_project_extensions(
            &self,
            _project_id: ProjectId,
        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
            unimplemented!()
        }
2539
        /// Not supported by the fake database; calling it panics via `unimplemented!()`.
        async fn record_user_activity(
            &self,
            _time_period: Range<OffsetDateTime>,
            _active_projects: &[(UserId, ProjectId)],
        ) -> Result<()> {
            unimplemented!()
        }
2547
        /// Not supported by the fake database; calling it panics via `unimplemented!()`.
        async fn get_active_user_count(
            &self,
            _time_period: Range<OffsetDateTime>,
            _min_duration: Duration,
        ) -> Result<usize> {
            unimplemented!()
        }
2555
        /// Not supported by the fake database; calling it panics via `unimplemented!()`.
        async fn get_top_users_activity_summary(
            &self,
            _time_period: Range<OffsetDateTime>,
            _limit: usize,
        ) -> Result<Vec<UserActivitySummary>> {
            unimplemented!()
        }
2563
        /// Not supported by the fake database; calling it panics via `unimplemented!()`.
        async fn get_user_activity_timeline(
            &self,
            _time_period: Range<OffsetDateTime>,
            _user_id: UserId,
        ) -> Result<Vec<UserActivityPeriod>> {
            unimplemented!()
        }
2571
2572        // contacts
2573
2574        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2575            self.background.simulate_random_delay().await;
2576            let mut contacts = vec![Contact::Accepted {
2577                user_id: id,
2578                should_notify: false,
2579            }];
2580
2581            for contact in self.contacts.lock().iter() {
2582                if contact.requester_id == id {
2583                    if contact.accepted {
2584                        contacts.push(Contact::Accepted {
2585                            user_id: contact.responder_id,
2586                            should_notify: contact.should_notify,
2587                        });
2588                    } else {
2589                        contacts.push(Contact::Outgoing {
2590                            user_id: contact.responder_id,
2591                        });
2592                    }
2593                } else if contact.responder_id == id {
2594                    if contact.accepted {
2595                        contacts.push(Contact::Accepted {
2596                            user_id: contact.requester_id,
2597                            should_notify: false,
2598                        });
2599                    } else {
2600                        contacts.push(Contact::Incoming {
2601                            user_id: contact.requester_id,
2602                            should_notify: contact.should_notify,
2603                        });
2604                    }
2605                }
2606            }
2607
2608            contacts.sort_unstable_by_key(|contact| contact.user_id());
2609            Ok(contacts)
2610        }
2611
2612        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2613            self.background.simulate_random_delay().await;
2614            Ok(self.contacts.lock().iter().any(|contact| {
2615                contact.accepted
2616                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2617                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2618            }))
2619        }
2620
2621        async fn send_contact_request(
2622            &self,
2623            requester_id: UserId,
2624            responder_id: UserId,
2625        ) -> Result<()> {
2626            self.background.simulate_random_delay().await;
2627            let mut contacts = self.contacts.lock();
2628            for contact in contacts.iter_mut() {
2629                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2630                    if contact.accepted {
2631                        Err(anyhow!("contact already exists"))?;
2632                    } else {
2633                        Err(anyhow!("contact already requested"))?;
2634                    }
2635                }
2636                if contact.responder_id == requester_id && contact.requester_id == responder_id {
2637                    if contact.accepted {
2638                        Err(anyhow!("contact already exists"))?;
2639                    } else {
2640                        contact.accepted = true;
2641                        contact.should_notify = false;
2642                        return Ok(());
2643                    }
2644                }
2645            }
2646            contacts.push(FakeContact {
2647                requester_id,
2648                responder_id,
2649                accepted: false,
2650                should_notify: true,
2651            });
2652            Ok(())
2653        }
2654
2655        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2656            self.background.simulate_random_delay().await;
2657            self.contacts.lock().retain(|contact| {
2658                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2659            });
2660            Ok(())
2661        }
2662
2663        async fn dismiss_contact_notification(
2664            &self,
2665            user_id: UserId,
2666            contact_user_id: UserId,
2667        ) -> Result<()> {
2668            self.background.simulate_random_delay().await;
2669            let mut contacts = self.contacts.lock();
2670            for contact in contacts.iter_mut() {
2671                if contact.requester_id == contact_user_id
2672                    && contact.responder_id == user_id
2673                    && !contact.accepted
2674                {
2675                    contact.should_notify = false;
2676                    return Ok(());
2677                }
2678                if contact.requester_id == user_id
2679                    && contact.responder_id == contact_user_id
2680                    && contact.accepted
2681                {
2682                    contact.should_notify = false;
2683                    return Ok(());
2684                }
2685            }
2686            Err(anyhow!("no such notification"))?
2687        }
2688
2689        async fn respond_to_contact_request(
2690            &self,
2691            responder_id: UserId,
2692            requester_id: UserId,
2693            accept: bool,
2694        ) -> Result<()> {
2695            self.background.simulate_random_delay().await;
2696            let mut contacts = self.contacts.lock();
2697            for (ix, contact) in contacts.iter_mut().enumerate() {
2698                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2699                    if contact.accepted {
2700                        Err(anyhow!("contact already confirmed"))?;
2701                    }
2702                    if accept {
2703                        contact.accepted = true;
2704                        contact.should_notify = true;
2705                    } else {
2706                        contacts.remove(ix);
2707                    }
2708                    return Ok(());
2709                }
2710            }
2711            Err(anyhow!("no such contact request"))?
2712        }
2713
        /// Not supported by this fake: calling it panics via `unimplemented!()`.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }
2722
        /// Not supported by this fake: calling it panics via `unimplemented!()`.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }
2726
        /// Not supported by this fake: calling it panics via `unimplemented!()`.
        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }
2730
2731        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2732            self.background.simulate_random_delay().await;
2733            let mut orgs = self.orgs.lock();
2734            if orgs.values().any(|org| org.slug == slug) {
2735                Err(anyhow!("org already exists"))?
2736            } else {
2737                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2738                orgs.insert(
2739                    org_id,
2740                    Org {
2741                        id: org_id,
2742                        name: name.to_string(),
2743                        slug: slug.to_string(),
2744                    },
2745                );
2746                Ok(org_id)
2747            }
2748        }
2749
2750        async fn add_org_member(
2751            &self,
2752            org_id: OrgId,
2753            user_id: UserId,
2754            is_admin: bool,
2755        ) -> Result<()> {
2756            self.background.simulate_random_delay().await;
2757            if !self.orgs.lock().contains_key(&org_id) {
2758                Err(anyhow!("org does not exist"))?;
2759            }
2760            if !self.users.lock().contains_key(&user_id) {
2761                Err(anyhow!("user does not exist"))?;
2762            }
2763
2764            self.org_memberships
2765                .lock()
2766                .entry((org_id, user_id))
2767                .or_insert(is_admin);
2768            Ok(())
2769        }
2770
2771        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2772            self.background.simulate_random_delay().await;
2773            if !self.orgs.lock().contains_key(&org_id) {
2774                Err(anyhow!("org does not exist"))?;
2775            }
2776
2777            let mut channels = self.channels.lock();
2778            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2779            channels.insert(
2780                channel_id,
2781                Channel {
2782                    id: channel_id,
2783                    name: name.to_string(),
2784                    owner_id: org_id.0,
2785                    owner_is_user: false,
2786                },
2787            );
2788            Ok(channel_id)
2789        }
2790
2791        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2792            self.background.simulate_random_delay().await;
2793            Ok(self
2794                .channels
2795                .lock()
2796                .values()
2797                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2798                .cloned()
2799                .collect())
2800        }
2801
2802        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2803            self.background.simulate_random_delay().await;
2804            let channels = self.channels.lock();
2805            let memberships = self.channel_memberships.lock();
2806            Ok(channels
2807                .values()
2808                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2809                .cloned()
2810                .collect())
2811        }
2812
2813        async fn can_user_access_channel(
2814            &self,
2815            user_id: UserId,
2816            channel_id: ChannelId,
2817        ) -> Result<bool> {
2818            self.background.simulate_random_delay().await;
2819            Ok(self
2820                .channel_memberships
2821                .lock()
2822                .contains_key(&(channel_id, user_id)))
2823        }
2824
2825        async fn add_channel_member(
2826            &self,
2827            channel_id: ChannelId,
2828            user_id: UserId,
2829            is_admin: bool,
2830        ) -> Result<()> {
2831            self.background.simulate_random_delay().await;
2832            if !self.channels.lock().contains_key(&channel_id) {
2833                Err(anyhow!("channel does not exist"))?;
2834            }
2835            if !self.users.lock().contains_key(&user_id) {
2836                Err(anyhow!("user does not exist"))?;
2837            }
2838
2839            self.channel_memberships
2840                .lock()
2841                .entry((channel_id, user_id))
2842                .or_insert(is_admin);
2843            Ok(())
2844        }
2845
2846        async fn create_channel_message(
2847            &self,
2848            channel_id: ChannelId,
2849            sender_id: UserId,
2850            body: &str,
2851            timestamp: OffsetDateTime,
2852            nonce: u128,
2853        ) -> Result<MessageId> {
2854            self.background.simulate_random_delay().await;
2855            if !self.channels.lock().contains_key(&channel_id) {
2856                Err(anyhow!("channel does not exist"))?;
2857            }
2858            if !self.users.lock().contains_key(&sender_id) {
2859                Err(anyhow!("user does not exist"))?;
2860            }
2861
2862            let mut messages = self.channel_messages.lock();
2863            if let Some(message) = messages
2864                .values()
2865                .find(|message| message.nonce.as_u128() == nonce)
2866            {
2867                Ok(message.id)
2868            } else {
2869                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2870                messages.insert(
2871                    message_id,
2872                    ChannelMessage {
2873                        id: message_id,
2874                        channel_id,
2875                        sender_id,
2876                        body: body.to_string(),
2877                        sent_at: timestamp,
2878                        nonce: Uuid::from_u128(nonce),
2879                    },
2880                );
2881                Ok(message_id)
2882            }
2883        }
2884
2885        async fn get_channel_messages(
2886            &self,
2887            channel_id: ChannelId,
2888            count: usize,
2889            before_id: Option<MessageId>,
2890        ) -> Result<Vec<ChannelMessage>> {
2891            self.background.simulate_random_delay().await;
2892            let mut messages = self
2893                .channel_messages
2894                .lock()
2895                .values()
2896                .rev()
2897                .filter(|message| {
2898                    message.channel_id == channel_id
2899                        && message.id < before_id.unwrap_or(MessageId::MAX)
2900                })
2901                .take(count)
2902                .cloned()
2903                .collect::<Vec<_>>();
2904            messages.sort_unstable_by_key(|message| message.id);
2905            Ok(messages)
2906        }
2907
        /// No-op for the fake: the argument is ignored and nothing is torn down.
        async fn teardown(&self, _: &str) {}
2909
        /// Returns `Some(self)`, giving callers access to the concrete `FakeDb`.
        /// Only compiled for test builds.
        #[cfg(test)]
        fn as_fake(&self) -> Option<&FakeDb> {
            Some(self)
        }
2914    }
2915
2916    fn build_background_executor() -> Arc<Background> {
2917        Deterministic::new(0).build_background()
2918    }
2919}