db.rs

   1use std::{ops::Range, time::Duration};
   2
   3use crate::{Error, Result};
   4use anyhow::{anyhow, Context};
   5use async_trait::async_trait;
   6use axum::http::StatusCode;
   7use collections::HashMap;
   8use futures::StreamExt;
   9use nanoid::nanoid;
  10use serde::Serialize;
  11pub use sqlx::postgres::PgPoolOptions as DbOptions;
  12use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
  13use time::{OffsetDateTime, PrimitiveDateTime};
  14
/// Persistence interface used by the collaboration server.
///
/// [`PostgresDb`] in this file implements it. Methods marked
/// `cfg(any(test, feature = "seed-support"))` only exist in test/seeding
/// builds; `teardown` and `as_fake` only exist under `cfg(test)`.
#[async_trait]
pub trait Db: Send + Sync {
    /// Creates a user with the given GitHub login and returns its id.
    ///
    /// NOTE(review): the Postgres implementation upserts on `github_login`,
    /// so an existing login yields the existing user's id instead of an error.
    async fn create_user(
        &self,
        github_login: &str,
        email_address: Option<&str>,
        admin: bool,
    ) -> Result<UserId>;
    /// Returns one page of users ordered by GitHub login. In the Postgres
    /// implementation the offset is `page * limit`, i.e. `page` is zero-based.
    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
    /// Bulk-upserts users given as `(github_login, email_address, invite_count)`
    /// tuples and returns the ids of the affected rows.
    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
    /// Returns up to `limit` users whose GitHub login fuzzily matches `query`.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Returns the user with the given id, if any.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Returns the users whose ids are in `ids`; unknown ids are skipped.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns users with zero remaining invites, restricted to users that do
    /// (`true`) or do not (`false`) have an inviter.
    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
    /// Returns the user with exactly this GitHub login, if any.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets or clears the user's admin flag.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Sets the user's `connected_once` flag.
    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
    /// Deletes the user together with their access tokens.
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    /// Sets how many invites the user has left. The Postgres implementation
    /// also assigns an invite code when `count > 0` and none exists yet.
    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
    /// Returns the user's invite code and remaining invite count, or `None`
    /// if the user has no invite code.
    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
    /// Returns the owner of the invite code; errors with HTTP 404 if unknown.
    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
    /// Consumes one invite of `code`, creates the invitee user, and makes
    /// inviter and invitee accepted contacts. Returns the invitee's id.
    async fn redeem_invite_code(
        &self,
        code: &str,
        login: &str,
        email_address: Option<&str>,
    ) -> Result<UserId>;

    /// Registers a new project for the given user.
    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;

    /// Unregisters a project for the given project id.
    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;

    /// Update file counts by extension for the given project and worktree.
    async fn update_worktree_extensions(
        &self,
        project_id: ProjectId,
        worktree_id: u64,
        extensions: HashMap<String, usize>,
    ) -> Result<()>;

    /// Get the file counts on the given project keyed by their worktree and extension.
    async fn get_project_extensions(
        &self,
        project_id: ProjectId,
    ) -> Result<HashMap<u64, HashMap<String, usize>>>;

    /// Record which users have been active in which projects during
    /// a given period of time.
    async fn record_project_activity(
        &self,
        time_period: Range<OffsetDateTime>,
        active_projects: &[(UserId, ProjectId)],
    ) -> Result<()>;

    /// Get the users that have been most active during the given time period,
    /// along with the amount of time they have been active in each project.
    async fn summarize_project_activity(
        &self,
        time_period: Range<OffsetDateTime>,
        max_user_count: usize,
    ) -> Result<Vec<UserActivitySummary>>;

    /// Get the project activity for the given user and time period.
    async fn summarize_user_activity(
        &self,
        user_id: UserId,
        time_period: Range<OffsetDateTime>,
    ) -> Result<Vec<UserActivityDuration>>;

    /// Returns the user's contacts and pending contact requests (the Postgres
    /// implementation also includes the user itself as an accepted contact).
    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    /// Returns whether the two users are accepted contacts.
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    /// Sends a contact request from `requester_id` to `responder_id`; the
    /// Postgres implementation accepts instead if the opposite request exists.
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Removes the contact (accepted or pending) between the two users.
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Clears the pending notification for a contact relationship.
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    /// Accepts (`accept == true`) or declines a pending contact request.
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    /// Stores a new access token hash for the user.
    /// NOTE(review): presumably retains at most `max_access_token_count`
    /// hashes — the implementation is outside this view; confirm.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    /// Returns the stored access token hashes for the user.
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
    /// Looks up an organization by slug (test/seed builds only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    /// Creates an organization and returns its id (test/seed builds only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    /// Adds a user to an organization (test/seed builds only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    /// Creates a channel in an organization (test/seed builds only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    /// Returns an organization's channels. Gated by the `cfg` attribute
    /// directly below, so it only exists in test/seed builds.
    #[cfg(any(test, feature = "seed-support"))]
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    /// Returns the channels the user can access.
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    /// Returns whether the user can access the given channel.
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    /// Adds a user to a channel (test/seed builds only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    /// Stores a message sent to a channel and returns its id.
    /// NOTE(review): `nonce` is presumably a client-chosen deduplication key — confirm.
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Returns up to `count` messages of the channel; `before_id` presumably
    /// restricts the result to older messages (pagination) — confirm.
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    /// Test-only: tears down the database identified by `url`.
    #[cfg(test)]
    async fn teardown(&self, url: &str);
    /// Test-only: returns the fake implementation if `self` is one.
    #[cfg(test)]
    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
}
 151
 152pub struct PostgresDb {
 153    pool: sqlx::PgPool,
 154}
 155
 156impl PostgresDb {
 157    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 158        let pool = DbOptions::new()
 159            .max_connections(max_connections)
 160            .connect(&url)
 161            .await
 162            .context("failed to connect to postgres database")?;
 163        Ok(Self { pool })
 164    }
 165}
 166
 167#[async_trait]
 168impl Db for PostgresDb {
 169    // users
 170
 171    async fn create_user(
 172        &self,
 173        github_login: &str,
 174        email_address: Option<&str>,
 175        admin: bool,
 176    ) -> Result<UserId> {
 177        let query = "
 178            INSERT INTO users (github_login, email_address, admin)
 179            VALUES ($1, $2, $3)
 180            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 181            RETURNING id
 182        ";
 183        Ok(sqlx::query_scalar(query)
 184            .bind(github_login)
 185            .bind(email_address)
 186            .bind(admin)
 187            .fetch_one(&self.pool)
 188            .await
 189            .map(UserId)?)
 190    }
 191
 192    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 193        let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 194        Ok(sqlx::query_as(query)
 195            .bind(limit as i32)
 196            .bind((page * limit) as i32)
 197            .fetch_all(&self.pool)
 198            .await?)
 199    }
 200
 201    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
 202        let mut query = QueryBuilder::new(
 203            "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
 204        );
 205        query.push_values(
 206            users,
 207            |mut query, (github_login, email_address, invite_count)| {
 208                query
 209                    .push_bind(github_login)
 210                    .push_bind(email_address)
 211                    .push_bind(false)
 212                    .push_bind(nanoid!(16))
 213                    .push_bind(invite_count as i32);
 214            },
 215        );
 216        query.push(
 217            "
 218            ON CONFLICT (github_login) DO UPDATE SET
 219                github_login = excluded.github_login,
 220                invite_count = excluded.invite_count,
 221                invite_code = CASE WHEN users.invite_code IS NULL
 222                                   THEN excluded.invite_code
 223                                   ELSE users.invite_code
 224                              END
 225            RETURNING id
 226            ",
 227        );
 228
 229        let rows = query.build().fetch_all(&self.pool).await?;
 230        Ok(rows
 231            .into_iter()
 232            .filter_map(|row| row.try_get::<UserId, _>(0).ok())
 233            .collect())
 234    }
 235
 236    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 237        let like_string = fuzzy_like_string(name_query);
 238        let query = "
 239            SELECT users.*
 240            FROM users
 241            WHERE github_login ILIKE $1
 242            ORDER BY github_login <-> $2
 243            LIMIT $3
 244        ";
 245        Ok(sqlx::query_as(query)
 246            .bind(like_string)
 247            .bind(name_query)
 248            .bind(limit as i32)
 249            .fetch_all(&self.pool)
 250            .await?)
 251    }
 252
 253    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 254        let users = self.get_users_by_ids(vec![id]).await?;
 255        Ok(users.into_iter().next())
 256    }
 257
 258    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 259        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 260        let query = "
 261            SELECT users.*
 262            FROM users
 263            WHERE users.id = ANY ($1)
 264        ";
 265        Ok(sqlx::query_as(query)
 266            .bind(&ids)
 267            .fetch_all(&self.pool)
 268            .await?)
 269    }
 270
 271    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
 272        let query = format!(
 273            "
 274            SELECT users.*
 275            FROM users
 276            WHERE invite_count = 0
 277            AND inviter_id IS{} NULL
 278            ",
 279            if invited_by_another_user { " NOT" } else { "" }
 280        );
 281
 282        Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 283    }
 284
 285    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 286        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 287        Ok(sqlx::query_as(query)
 288            .bind(github_login)
 289            .fetch_optional(&self.pool)
 290            .await?)
 291    }
 292
 293    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 294        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 295        Ok(sqlx::query(query)
 296            .bind(is_admin)
 297            .bind(id.0)
 298            .execute(&self.pool)
 299            .await
 300            .map(drop)?)
 301    }
 302
 303    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 304        let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 305        Ok(sqlx::query(query)
 306            .bind(connected_once)
 307            .bind(id.0)
 308            .execute(&self.pool)
 309            .await
 310            .map(drop)?)
 311    }
 312
 313    async fn destroy_user(&self, id: UserId) -> Result<()> {
 314        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 315        sqlx::query(query)
 316            .bind(id.0)
 317            .execute(&self.pool)
 318            .await
 319            .map(drop)?;
 320        let query = "DELETE FROM users WHERE id = $1;";
 321        Ok(sqlx::query(query)
 322            .bind(id.0)
 323            .execute(&self.pool)
 324            .await
 325            .map(drop)?)
 326    }
 327
 328    // invite codes
 329
 330    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
 331        let mut tx = self.pool.begin().await?;
 332        if count > 0 {
 333            sqlx::query(
 334                "
 335                UPDATE users
 336                SET invite_code = $1
 337                WHERE id = $2 AND invite_code IS NULL
 338            ",
 339            )
 340            .bind(nanoid!(16))
 341            .bind(id)
 342            .execute(&mut tx)
 343            .await?;
 344        }
 345
 346        sqlx::query(
 347            "
 348            UPDATE users
 349            SET invite_count = $1
 350            WHERE id = $2
 351            ",
 352        )
 353        .bind(count as i32)
 354        .bind(id)
 355        .execute(&mut tx)
 356        .await?;
 357        tx.commit().await?;
 358        Ok(())
 359    }
 360
 361    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 362        let result: Option<(String, i32)> = sqlx::query_as(
 363            "
 364                SELECT invite_code, invite_count
 365                FROM users
 366                WHERE id = $1 AND invite_code IS NOT NULL 
 367            ",
 368        )
 369        .bind(id)
 370        .fetch_optional(&self.pool)
 371        .await?;
 372        if let Some((code, count)) = result {
 373            Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 374        } else {
 375            Ok(None)
 376        }
 377    }
 378
 379    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 380        sqlx::query_as(
 381            "
 382                SELECT *
 383                FROM users
 384                WHERE invite_code = $1
 385            ",
 386        )
 387        .bind(code)
 388        .fetch_optional(&self.pool)
 389        .await?
 390        .ok_or_else(|| {
 391            Error::Http(
 392                StatusCode::NOT_FOUND,
 393                "that invite code does not exist".to_string(),
 394            )
 395        })
 396    }
 397
    /// Redeems `code` for a new user with the given login and email, and
    /// returns the new user's id.
    ///
    /// Everything runs in one transaction:
    /// 1. decrement the inviter's `invite_count`;
    /// 2. insert the invitee as a non-admin with `inviter_id` set;
    /// 3. insert an accepted contact row between inviter and invitee.
    ///
    /// An early return via `?` drops `tx` without committing, and sqlx rolls
    /// back a dropped transaction.
    ///
    /// # Errors
    /// - `Error::Http(UNAUTHORIZED, "no invites remaining")` if the code exists
    ///   but has no invites left.
    /// - `Error::Http(NOT_FOUND, "invite code not found")` if no user owns it.
    /// - Any database error, e.g. if inserting the invitee fails.
    async fn redeem_invite_code(
        &self,
        code: &str,
        login: &str,
        email_address: Option<&str>,
    ) -> Result<UserId> {
        let mut tx = self.pool.begin().await?;

        // Claim one invite in a single statement. Because of the
        // `invite_count > 0` guard, no row is returned both for unknown codes
        // and for exhausted ones.
        let inviter_id: Option<UserId> = sqlx::query_scalar(
            "
                UPDATE users
                SET invite_count = invite_count - 1
                WHERE
                    invite_code = $1 AND
                    invite_count > 0
                RETURNING id
            ",
        )
        .bind(code)
        .fetch_optional(&mut tx)
        .await?;

        let inviter_id = match inviter_id {
            Some(inviter_id) => inviter_id,
            None => {
                // Nothing was claimed: check whether the code exists at all
                // to pick the right error.
                if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
                    .bind(code)
                    .fetch_optional(&mut tx)
                    .await?
                    .is_some()
                {
                    Err(Error::Http(
                        StatusCode::UNAUTHORIZED,
                        "no invites remaining".to_string(),
                    ))?
                } else {
                    Err(Error::Http(
                        StatusCode::NOT_FOUND,
                        "invite code not found".to_string(),
                    ))?
                }
            }
        };

        // Create the invitee as a non-admin linked to its inviter.
        let invitee_id = sqlx::query_scalar(
            "
                INSERT INTO users
                    (github_login, email_address, admin, inviter_id)
                VALUES
                    ($1, $2, 'f', $3)
                RETURNING id
            ",
        )
        .bind(login)
        .bind(email_address)
        .bind(inviter_id)
        .fetch_one(&mut tx)
        .await
        .map(UserId)?;

        // Make the two users accepted contacts right away, with `a_to_b` and
        // `should_notify` set on the inviter -> invitee row.
        // NOTE(review): the other contact queries store pairs with
        // `user_id_a < user_id_b`. Putting the inviter in `user_id_a` relies
        // on the freshly inserted invitee receiving a larger id — confirm.
        sqlx::query(
            "
                INSERT INTO contacts
                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
                VALUES
                    ($1, $2, 't', 't', 't')
            ",
        )
        .bind(inviter_id)
        .bind(invitee_id)
        .execute(&mut tx)
        .await?;

        tx.commit().await?;
        Ok(invitee_id)
    }
 474
 475    // projects
 476
 477    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
 478        Ok(sqlx::query_scalar(
 479            "
 480            INSERT INTO projects(host_user_id)
 481            VALUES ($1)
 482            RETURNING id
 483            ",
 484        )
 485        .bind(host_user_id)
 486        .fetch_one(&self.pool)
 487        .await
 488        .map(ProjectId)?)
 489    }
 490
 491    async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
 492        sqlx::query(
 493            "
 494            UPDATE projects
 495            SET unregistered = 't'
 496            WHERE id = $1
 497            ",
 498        )
 499        .bind(project_id)
 500        .execute(&self.pool)
 501        .await?;
 502        Ok(())
 503    }
 504
 505    async fn update_worktree_extensions(
 506        &self,
 507        project_id: ProjectId,
 508        worktree_id: u64,
 509        extensions: HashMap<String, usize>,
 510    ) -> Result<()> {
 511        if extensions.is_empty() {
 512            return Ok(());
 513        }
 514
 515        let mut query = QueryBuilder::new(
 516            "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
 517        );
 518        query.push_values(extensions, |mut query, (extension, count)| {
 519            query
 520                .push_bind(project_id)
 521                .push_bind(worktree_id as i32)
 522                .push_bind(extension)
 523                .push_bind(count as i32);
 524        });
 525        query.push(
 526            "
 527            ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
 528            count = excluded.count
 529            ",
 530        );
 531        query.build().execute(&self.pool).await?;
 532
 533        Ok(())
 534    }
 535
 536    async fn get_project_extensions(
 537        &self,
 538        project_id: ProjectId,
 539    ) -> Result<HashMap<u64, HashMap<String, usize>>> {
 540        #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
 541        struct WorktreeExtension {
 542            worktree_id: i32,
 543            extension: String,
 544            count: i32,
 545        }
 546
 547        let query = "
 548            SELECT worktree_id, extension, count
 549            FROM worktree_extensions
 550            WHERE project_id = $1
 551        ";
 552        let counts = sqlx::query_as::<_, WorktreeExtension>(query)
 553            .bind(&project_id)
 554            .fetch_all(&self.pool)
 555            .await?;
 556
 557        let mut extension_counts = HashMap::default();
 558        for count in counts {
 559            extension_counts
 560                .entry(count.worktree_id as u64)
 561                .or_insert(HashMap::default())
 562                .insert(count.extension, count.count as usize);
 563        }
 564        Ok(extension_counts)
 565    }
 566
 567    async fn record_project_activity(
 568        &self,
 569        time_period: Range<OffsetDateTime>,
 570        projects: &[(UserId, ProjectId)],
 571    ) -> Result<()> {
 572        let query = "
 573            INSERT INTO project_activity_periods
 574            (ended_at, duration_millis, user_id, project_id)
 575            VALUES
 576            ($1, $2, $3, $4);
 577        ";
 578
 579        let mut tx = self.pool.begin().await?;
 580        let duration_millis =
 581            ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
 582        for (user_id, project_id) in projects {
 583            sqlx::query(query)
 584                .bind(time_period.end)
 585                .bind(duration_millis)
 586                .bind(user_id)
 587                .bind(project_id)
 588                .execute(&mut tx)
 589                .await?;
 590        }
 591        tx.commit().await?;
 592
 593        Ok(())
 594    }
 595
 596    async fn summarize_project_activity(
 597        &self,
 598        time_period: Range<OffsetDateTime>,
 599        max_user_count: usize,
 600    ) -> Result<Vec<UserActivitySummary>> {
 601        let query = "
 602            WITH
 603                project_durations AS (
 604                    SELECT user_id, project_id, SUM(duration_millis) AS project_duration
 605                    FROM project_activity_periods
 606                    WHERE $1 < ended_at AND ended_at <= $2
 607                    GROUP BY user_id, project_id
 608                ),
 609                user_durations AS (
 610                    SELECT user_id, SUM(project_duration) as total_duration
 611                    FROM project_durations
 612                    GROUP BY user_id
 613                    ORDER BY total_duration DESC
 614                    LIMIT $3
 615                )
 616            SELECT user_durations.user_id, users.github_login, project_id, project_duration
 617            FROM user_durations, project_durations, users
 618            WHERE
 619                user_durations.user_id = project_durations.user_id AND
 620                user_durations.user_id = users.id
 621            ORDER BY user_id ASC, project_duration DESC
 622        ";
 623
 624        let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64)>(query)
 625            .bind(time_period.start)
 626            .bind(time_period.end)
 627            .bind(max_user_count as i32)
 628            .fetch(&self.pool);
 629
 630        let mut result = Vec::<UserActivitySummary>::new();
 631        while let Some(row) = rows.next().await {
 632            let (user_id, github_login, project_id, duration_millis) = row?;
 633            let project_id = project_id;
 634            let duration = Duration::from_millis(duration_millis as u64);
 635            if let Some(last_summary) = result.last_mut() {
 636                if last_summary.id == user_id {
 637                    last_summary.project_activity.push((project_id, duration));
 638                    continue;
 639                }
 640            }
 641            result.push(UserActivitySummary {
 642                id: user_id,
 643                project_activity: vec![(project_id, duration)],
 644                github_login,
 645            });
 646        }
 647
 648        Ok(result)
 649    }
 650
    /// Returns the spans of time `user_id` was active within `time_period`,
    /// across all projects and sorted by start time.
    ///
    /// A period is selected when its `ended_at` lies in
    /// `(time_period.start, time_period.end]`. Consecutive periods of the same
    /// project are merged when the next one starts at most
    /// `COALESCE_THRESHOLD` after the current span ends. Each span's
    /// `extensions` map is filled from the project's `worktree_extensions` rows.
    async fn summarize_user_activity(
        &self,
        user_id: UserId,
        time_period: Range<OffsetDateTime>,
    ) -> Result<Vec<UserActivityDuration>> {
        const COALESCE_THRESHOLD: Duration = Duration::from_secs(5);

        // The LEFT OUTER JOIN yields one row per (period, extension) pair, or
        // one row with NULL extension columns when the project has no counts.
        let query = "
            SELECT
                project_activity_periods.ended_at,
                project_activity_periods.duration_millis,
                project_activity_periods.project_id,
                worktree_extensions.extension,
                worktree_extensions.count
            FROM project_activity_periods
            LEFT OUTER JOIN
                worktree_extensions
            ON
                project_activity_periods.project_id = worktree_extensions.project_id
            WHERE
                project_activity_periods.user_id = $1 AND
                $2 < project_activity_periods.ended_at AND
                project_activity_periods.ended_at <= $3
            ORDER BY project_activity_periods.id ASC
        ";

        let mut rows = sqlx::query_as::<
            _,
            (
                PrimitiveDateTime,
                i32,
                ProjectId,
                Option<String>,
                Option<i32>,
            ),
        >(query)
        .bind(user_id)
        .bind(time_period.start)
        .bind(time_period.end)
        .fetch(&self.pool);

        let mut durations: HashMap<ProjectId, Vec<UserActivityDuration>> = Default::default();
        while let Some(row) = rows.next().await {
            let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
            // `ended_at` is decoded without an offset and interpreted as UTC.
            let ended_at = ended_at.assume_utc();
            // NOTE(review): a negative `duration_millis` would wrap to a huge
            // u64 here; presumably never stored — confirm.
            let duration = Duration::from_millis(duration_millis as u64);
            let started_at = ended_at - duration;
            let project_durations = durations.entry(project_id).or_default();

            // Extend the project's latest span if this period starts close
            // enough after it. Repeated join rows for the same period start
            // before that span's end, so they always take this branch.
            if let Some(prev_duration) = project_durations.last_mut() {
                if started_at - prev_duration.end <= COALESCE_THRESHOLD {
                    prev_duration.end = ended_at;
                } else {
                    project_durations.push(UserActivityDuration {
                        project_id,
                        start: started_at,
                        end: ended_at,
                        extensions: Default::default(),
                    });
                }
            } else {
                project_durations.push(UserActivityDuration {
                    project_id,
                    start: started_at,
                    end: ended_at,
                    extensions: Default::default(),
                });
            }

            // The branch above always leaves a span in `project_durations`, so
            // `last_mut` cannot be `None`.
            // NOTE(review): counts are keyed only by extension, so when a
            // project has several worktrees with the same extension the last
            // joined row wins instead of being summed — confirm intended.
            if let Some((extension, extension_count)) = extension.zip(extension_count) {
                project_durations
                    .last_mut()
                    .unwrap()
                    .extensions
                    .insert(extension, extension_count as usize);
            }
        }

        let mut durations = durations.into_values().flatten().collect::<Vec<_>>();
        durations.sort_unstable_by_key(|duration| duration.start);
        Ok(durations)
    }
 733
 734    // contacts
 735
 736    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 737        let query = "
 738            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 739            FROM contacts
 740            WHERE user_id_a = $1 OR user_id_b = $1;
 741        ";
 742
 743        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 744            .bind(user_id)
 745            .fetch(&self.pool);
 746
 747        let mut contacts = vec![Contact::Accepted {
 748            user_id,
 749            should_notify: false,
 750        }];
 751        while let Some(row) = rows.next().await {
 752            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 753
 754            if user_id_a == user_id {
 755                if accepted {
 756                    contacts.push(Contact::Accepted {
 757                        user_id: user_id_b,
 758                        should_notify: should_notify && a_to_b,
 759                    });
 760                } else if a_to_b {
 761                    contacts.push(Contact::Outgoing { user_id: user_id_b })
 762                } else {
 763                    contacts.push(Contact::Incoming {
 764                        user_id: user_id_b,
 765                        should_notify,
 766                    });
 767                }
 768            } else {
 769                if accepted {
 770                    contacts.push(Contact::Accepted {
 771                        user_id: user_id_a,
 772                        should_notify: should_notify && !a_to_b,
 773                    });
 774                } else if a_to_b {
 775                    contacts.push(Contact::Incoming {
 776                        user_id: user_id_a,
 777                        should_notify,
 778                    });
 779                } else {
 780                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 781                }
 782            }
 783        }
 784
 785        contacts.sort_unstable_by_key(|contact| contact.user_id());
 786
 787        Ok(contacts)
 788    }
 789
 790    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 791        let (id_a, id_b) = if user_id_1 < user_id_2 {
 792            (user_id_1, user_id_2)
 793        } else {
 794            (user_id_2, user_id_1)
 795        };
 796
 797        let query = "
 798            SELECT 1 FROM contacts
 799            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
 800            LIMIT 1
 801        ";
 802        Ok(sqlx::query_scalar::<_, i32>(query)
 803            .bind(id_a.0)
 804            .bind(id_b.0)
 805            .fetch_optional(&self.pool)
 806            .await?
 807            .is_some())
 808    }
 809
 810    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 811        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 812            (sender_id, receiver_id, true)
 813        } else {
 814            (receiver_id, sender_id, false)
 815        };
 816        let query = "
 817            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 818            VALUES ($1, $2, $3, 'f', 't')
 819            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 820            SET
 821                accepted = 't',
 822                should_notify = 'f'
 823            WHERE
 824                NOT contacts.accepted AND
 825                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 826                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 827        ";
 828        let result = sqlx::query(query)
 829            .bind(id_a.0)
 830            .bind(id_b.0)
 831            .bind(a_to_b)
 832            .execute(&self.pool)
 833            .await?;
 834
 835        if result.rows_affected() == 1 {
 836            Ok(())
 837        } else {
 838            Err(anyhow!("contact already requested"))?
 839        }
 840    }
 841
 842    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 843        let (id_a, id_b) = if responder_id < requester_id {
 844            (responder_id, requester_id)
 845        } else {
 846            (requester_id, responder_id)
 847        };
 848        let query = "
 849            DELETE FROM contacts
 850            WHERE user_id_a = $1 AND user_id_b = $2;
 851        ";
 852        let result = sqlx::query(query)
 853            .bind(id_a.0)
 854            .bind(id_b.0)
 855            .execute(&self.pool)
 856            .await?;
 857
 858        if result.rows_affected() == 1 {
 859            Ok(())
 860        } else {
 861            Err(anyhow!("no such contact"))?
 862        }
 863    }
 864
 865    async fn dismiss_contact_notification(
 866        &self,
 867        user_id: UserId,
 868        contact_user_id: UserId,
 869    ) -> Result<()> {
 870        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
 871            (user_id, contact_user_id, true)
 872        } else {
 873            (contact_user_id, user_id, false)
 874        };
 875
 876        let query = "
 877            UPDATE contacts
 878            SET should_notify = 'f'
 879            WHERE
 880                user_id_a = $1 AND user_id_b = $2 AND
 881                (
 882                    (a_to_b = $3 AND accepted) OR
 883                    (a_to_b != $3 AND NOT accepted)
 884                );
 885        ";
 886
 887        let result = sqlx::query(query)
 888            .bind(id_a.0)
 889            .bind(id_b.0)
 890            .bind(a_to_b)
 891            .execute(&self.pool)
 892            .await?;
 893
 894        if result.rows_affected() == 0 {
 895            Err(anyhow!("no such contact request"))?;
 896        }
 897
 898        Ok(())
 899    }
 900
 901    async fn respond_to_contact_request(
 902        &self,
 903        responder_id: UserId,
 904        requester_id: UserId,
 905        accept: bool,
 906    ) -> Result<()> {
 907        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
 908            (responder_id, requester_id, false)
 909        } else {
 910            (requester_id, responder_id, true)
 911        };
 912        let result = if accept {
 913            let query = "
 914                UPDATE contacts
 915                SET accepted = 't', should_notify = 't'
 916                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
 917            ";
 918            sqlx::query(query)
 919                .bind(id_a.0)
 920                .bind(id_b.0)
 921                .bind(a_to_b)
 922                .execute(&self.pool)
 923                .await?
 924        } else {
 925            let query = "
 926                DELETE FROM contacts
 927                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
 928            ";
 929            sqlx::query(query)
 930                .bind(id_a.0)
 931                .bind(id_b.0)
 932                .bind(a_to_b)
 933                .execute(&self.pool)
 934                .await?
 935        };
 936        if result.rows_affected() == 1 {
 937            Ok(())
 938        } else {
 939            Err(anyhow!("no such contact request"))?
 940        }
 941    }
 942
 943    // access tokens
 944
 945    async fn create_access_token_hash(
 946        &self,
 947        user_id: UserId,
 948        access_token_hash: &str,
 949        max_access_token_count: usize,
 950    ) -> Result<()> {
 951        let insert_query = "
 952            INSERT INTO access_tokens (user_id, hash)
 953            VALUES ($1, $2);
 954        ";
 955        let cleanup_query = "
 956            DELETE FROM access_tokens
 957            WHERE id IN (
 958                SELECT id from access_tokens
 959                WHERE user_id = $1
 960                ORDER BY id DESC
 961                OFFSET $3
 962            )
 963        ";
 964
 965        let mut tx = self.pool.begin().await?;
 966        sqlx::query(insert_query)
 967            .bind(user_id.0)
 968            .bind(access_token_hash)
 969            .execute(&mut tx)
 970            .await?;
 971        sqlx::query(cleanup_query)
 972            .bind(user_id.0)
 973            .bind(access_token_hash)
 974            .bind(max_access_token_count as i32)
 975            .execute(&mut tx)
 976            .await?;
 977        Ok(tx.commit().await?)
 978    }
 979
 980    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 981        let query = "
 982            SELECT hash
 983            FROM access_tokens
 984            WHERE user_id = $1
 985            ORDER BY id DESC
 986        ";
 987        Ok(sqlx::query_scalar(query)
 988            .bind(user_id.0)
 989            .fetch_all(&self.pool)
 990            .await?)
 991    }
 992
 993    // orgs
 994
 995    #[allow(unused)] // Help rust-analyzer
 996    #[cfg(any(test, feature = "seed-support"))]
 997    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
 998        let query = "
 999            SELECT *
1000            FROM orgs
1001            WHERE slug = $1
1002        ";
1003        Ok(sqlx::query_as(query)
1004            .bind(slug)
1005            .fetch_optional(&self.pool)
1006            .await?)
1007    }
1008
1009    #[cfg(any(test, feature = "seed-support"))]
1010    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1011        let query = "
1012            INSERT INTO orgs (name, slug)
1013            VALUES ($1, $2)
1014            RETURNING id
1015        ";
1016        Ok(sqlx::query_scalar(query)
1017            .bind(name)
1018            .bind(slug)
1019            .fetch_one(&self.pool)
1020            .await
1021            .map(OrgId)?)
1022    }
1023
1024    #[cfg(any(test, feature = "seed-support"))]
1025    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1026        let query = "
1027            INSERT INTO org_memberships (org_id, user_id, admin)
1028            VALUES ($1, $2, $3)
1029            ON CONFLICT DO NOTHING
1030        ";
1031        Ok(sqlx::query(query)
1032            .bind(org_id.0)
1033            .bind(user_id.0)
1034            .bind(is_admin)
1035            .execute(&self.pool)
1036            .await
1037            .map(drop)?)
1038    }
1039
1040    // channels
1041
1042    #[cfg(any(test, feature = "seed-support"))]
1043    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1044        let query = "
1045            INSERT INTO channels (owner_id, owner_is_user, name)
1046            VALUES ($1, false, $2)
1047            RETURNING id
1048        ";
1049        Ok(sqlx::query_scalar(query)
1050            .bind(org_id.0)
1051            .bind(name)
1052            .fetch_one(&self.pool)
1053            .await
1054            .map(ChannelId)?)
1055    }
1056
1057    #[allow(unused)] // Help rust-analyzer
1058    #[cfg(any(test, feature = "seed-support"))]
1059    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1060        let query = "
1061            SELECT *
1062            FROM channels
1063            WHERE
1064                channels.owner_is_user = false AND
1065                channels.owner_id = $1
1066        ";
1067        Ok(sqlx::query_as(query)
1068            .bind(org_id.0)
1069            .fetch_all(&self.pool)
1070            .await?)
1071    }
1072
1073    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1074        let query = "
1075            SELECT
1076                channels.*
1077            FROM
1078                channel_memberships, channels
1079            WHERE
1080                channel_memberships.user_id = $1 AND
1081                channel_memberships.channel_id = channels.id
1082        ";
1083        Ok(sqlx::query_as(query)
1084            .bind(user_id.0)
1085            .fetch_all(&self.pool)
1086            .await?)
1087    }
1088
1089    async fn can_user_access_channel(
1090        &self,
1091        user_id: UserId,
1092        channel_id: ChannelId,
1093    ) -> Result<bool> {
1094        let query = "
1095            SELECT id
1096            FROM channel_memberships
1097            WHERE user_id = $1 AND channel_id = $2
1098            LIMIT 1
1099        ";
1100        Ok(sqlx::query_scalar::<_, i32>(query)
1101            .bind(user_id.0)
1102            .bind(channel_id.0)
1103            .fetch_optional(&self.pool)
1104            .await
1105            .map(|e| e.is_some())?)
1106    }
1107
1108    #[cfg(any(test, feature = "seed-support"))]
1109    async fn add_channel_member(
1110        &self,
1111        channel_id: ChannelId,
1112        user_id: UserId,
1113        is_admin: bool,
1114    ) -> Result<()> {
1115        let query = "
1116            INSERT INTO channel_memberships (channel_id, user_id, admin)
1117            VALUES ($1, $2, $3)
1118            ON CONFLICT DO NOTHING
1119        ";
1120        Ok(sqlx::query(query)
1121            .bind(channel_id.0)
1122            .bind(user_id.0)
1123            .bind(is_admin)
1124            .execute(&self.pool)
1125            .await
1126            .map(drop)?)
1127    }
1128
1129    // messages
1130
1131    async fn create_channel_message(
1132        &self,
1133        channel_id: ChannelId,
1134        sender_id: UserId,
1135        body: &str,
1136        timestamp: OffsetDateTime,
1137        nonce: u128,
1138    ) -> Result<MessageId> {
1139        let query = "
1140            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1141            VALUES ($1, $2, $3, $4, $5)
1142            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1143            RETURNING id
1144        ";
1145        Ok(sqlx::query_scalar(query)
1146            .bind(channel_id.0)
1147            .bind(sender_id.0)
1148            .bind(body)
1149            .bind(timestamp)
1150            .bind(Uuid::from_u128(nonce))
1151            .fetch_one(&self.pool)
1152            .await
1153            .map(MessageId)?)
1154    }
1155
1156    async fn get_channel_messages(
1157        &self,
1158        channel_id: ChannelId,
1159        count: usize,
1160        before_id: Option<MessageId>,
1161    ) -> Result<Vec<ChannelMessage>> {
1162        let query = r#"
1163            SELECT * FROM (
1164                SELECT
1165                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1166                FROM
1167                    channel_messages
1168                WHERE
1169                    channel_id = $1 AND
1170                    id < $2
1171                ORDER BY id DESC
1172                LIMIT $3
1173            ) as recent_messages
1174            ORDER BY id ASC
1175        "#;
1176        Ok(sqlx::query_as(query)
1177            .bind(channel_id.0)
1178            .bind(before_id.unwrap_or(MessageId::MAX))
1179            .bind(count as i64)
1180            .fetch_all(&self.pool)
1181            .await?)
1182    }
1183
1184    #[cfg(test)]
1185    async fn teardown(&self, url: &str) {
1186        use util::ResultExt;
1187
1188        let query = "
1189            SELECT pg_terminate_backend(pg_stat_activity.pid)
1190            FROM pg_stat_activity
1191            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1192        ";
1193        sqlx::query(query).execute(&self.pool).await.log_err();
1194        self.pool.close().await;
1195        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1196            .await
1197            .log_err();
1198    }
1199
1200    #[cfg(test)]
1201    fn as_fake(&self) -> Option<&tests::FakeDb> {
1202        None
1203    }
1204}
1205
/// Declares `pub struct $name(pub i32)`, a `Copy` newtype used as a typed database id.
///
/// The generated type:
/// - encodes and decodes through sqlx, and serializes through serde, exactly like its
///   inner `i32` (both attributes are `transparent`);
/// - gets `MAX`, `from_proto`/`to_proto` conversions, and a `Display` impl that forwards
///   to the inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Forward so width/fill/etc. flags apply to the inner integer.
                std::fmt::Display::fmt(&self.0, f)
            }
        }

        // Not every id type uses every helper, hence the `allow(unused)`s.
        impl $name {
            /// Largest representable id (e.g. an exclusive upper bound in queries).
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from its protocol form.
            /// NOTE(review): `as` keeps only the low 32 bits of `value`.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $name(value as i32)
            }

            /// Converts to the protocol form. Negative ids sign-extend into large `u64`s.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }
    };
}
1237
// Typed id for `User` records.
id_type!(UserId);
/// A user record as loaded from the database (via the `FromRow` derive).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub email_address: Option<String>,
    /// Admin flag; toggled by `set_user_is_admin`.
    pub admin: bool,
    /// The user's invite code, if one has been generated.
    pub invite_code: Option<String>,
    /// Invite count set by `create_users` / `set_invite_count`.
    /// NOTE(review): confirm whether this tracks remaining or total invites.
    pub invite_count: i32,
    /// Flag maintained by `set_user_connected_once`.
    pub connected_once: bool,
}
1249
// Typed id for `Project` records.
id_type!(ProjectId);
/// A project record as loaded from the database.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// The hosting user — presumably the one passed to `register_project` (TODO confirm).
    pub host_user_id: UserId,
    /// NOTE(review): presumably set once the project has been unregistered; confirm.
    pub unregistered: bool,
}
1257
/// Per-user rollup produced by `summarize_project_activity`. It pairs each project the user
/// was active in with the time spent there during the queried range.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// `(project, time spent)` pairs for this user.
    pub project_activity: Vec<(ProjectId, Duration)>,
}
1264
/// One span of a user's activity in a single project, as returned by
/// `summarize_user_activity`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityDuration {
    project_id: ProjectId,
    start: OffsetDateTime,
    end: OffsetDateTime,
    /// File-extension counts recorded for the project via `update_worktree_extensions`
    /// (e.g. `{"rs": 5, "md": 7}`); empty when none were recorded.
    extensions: HashMap<String, usize>,
}
1272
// Typed id for `Org` rows.
id_type!(OrgId);
/// A row from the `orgs` table.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Short identifier that `find_org_by_slug` matches on.
    pub slug: String,
}
1280
// Typed id for `Channel` rows.
id_type!(ChannelId);
/// A row from the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Raw owner id. It is an `OrgId` value when `owner_is_user` is false (see
    /// `create_org_channel`), and presumably a `UserId` value otherwise.
    pub owner_id: i32,
    /// Presumably true when `owner_id` refers to a user rather than an org.
    pub owner_is_user: bool,
}
1289
// Typed id for `ChannelMessage` rows.
id_type!(MessageId);
/// A row from the `channel_messages` table.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; `get_channel_messages` reads it back `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Caller-supplied dedup key. `create_channel_message` returns the existing message's
    /// id when a nonce is reused.
    pub nonce: Uuid,
}
1300
/// One entry in a user's contact list, described from that user's point of view.
/// (Presumably — confirm against the code that builds these values.)
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The request between the two users has been accepted.
    Accepted {
        user_id: UserId,
        /// Whether there is still a notification to show for this contact.
        should_notify: bool,
    },
    /// A pending request this user sent to `user_id`.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request `user_id` sent to this user.
    Incoming {
        user_id: UserId,
        /// Whether there is still a notification to show for this request.
        should_notify: bool,
    },
}
1315
1316impl Contact {
1317    pub fn user_id(&self) -> UserId {
1318        match self {
1319            Contact::Accepted { user_id, .. } => *user_id,
1320            Contact::Outgoing { user_id } => *user_id,
1321            Contact::Incoming { user_id, .. } => *user_id,
1322        }
1323    }
1324}
1325
/// A pending contact request, identified by who sent it.
///
/// The derived ordering compares `requester_id` first, then `should_notify`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    /// The user who sent the request.
    pub requester_id: UserId,
    /// Whether there is still a notification to show for this request.
    pub should_notify: bool,
}
1331
/// Builds a SQL `LIKE` pattern that matches any text containing the alphanumeric
/// characters of `string` in order, with anything in between (e.g. `"a b"` -> `"%a%b%"`).
///
/// Non-alphanumeric characters are dropped entirely. As a result, `LIKE` metacharacters
/// such as `%` and `_` in the input never reach the pattern.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern: String = string
        .chars()
        .filter(|ch| ch.is_alphanumeric())
        .flat_map(|ch| ['%', ch])
        .collect();
    pattern.push('%');
    pattern
}
1343
1344#[cfg(test)]
1345pub mod tests {
1346    use super::*;
1347    use anyhow::anyhow;
1348    use collections::BTreeMap;
1349    use gpui::executor::Background;
1350    use lazy_static::lazy_static;
1351    use parking_lot::Mutex;
1352    use rand::prelude::*;
1353    use sqlx::{
1354        migrate::{MigrateDatabase, Migrator},
1355        Postgres,
1356    };
1357    use std::{path::Path, sync::Arc};
1358    use util::post_inc;
1359
1360    #[tokio::test(flavor = "multi_thread")]
1361    async fn test_get_users_by_ids() {
1362        for test_db in [
1363            TestDb::postgres().await,
1364            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1365        ] {
1366            let db = test_db.db();
1367
1368            let user = db.create_user("user", None, false).await.unwrap();
1369            let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1370            let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1371            let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1372
1373            assert_eq!(
1374                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1375                    .await
1376                    .unwrap(),
1377                vec![
1378                    User {
1379                        id: user,
1380                        github_login: "user".to_string(),
1381                        admin: false,
1382                        ..Default::default()
1383                    },
1384                    User {
1385                        id: friend1,
1386                        github_login: "friend-1".to_string(),
1387                        admin: false,
1388                        ..Default::default()
1389                    },
1390                    User {
1391                        id: friend2,
1392                        github_login: "friend-2".to_string(),
1393                        admin: false,
1394                        ..Default::default()
1395                    },
1396                    User {
1397                        id: friend3,
1398                        github_login: "friend-3".to_string(),
1399                        admin: false,
1400                        ..Default::default()
1401                    }
1402                ]
1403            );
1404        }
1405    }
1406
1407    #[tokio::test(flavor = "multi_thread")]
1408    async fn test_create_users() {
1409        let db = TestDb::postgres().await;
1410        let db = db.db();
1411
1412        // Create the first batch of users, ensuring invite counts are assigned
1413        // correctly and the respective invite codes are unique.
1414        let user_ids_batch_1 = db
1415            .create_users(vec![
1416                ("user1".to_string(), "hi@user1.com".to_string(), 5),
1417                ("user2".to_string(), "hi@user2.com".to_string(), 4),
1418                ("user3".to_string(), "hi@user3.com".to_string(), 3),
1419            ])
1420            .await
1421            .unwrap();
1422        assert_eq!(user_ids_batch_1.len(), 3);
1423
1424        let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
1425        assert_eq!(users.len(), 3);
1426        assert_eq!(users[0].github_login, "user1");
1427        assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
1428        assert_eq!(users[0].invite_count, 5);
1429        assert_eq!(users[1].github_login, "user2");
1430        assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
1431        assert_eq!(users[1].invite_count, 4);
1432        assert_eq!(users[2].github_login, "user3");
1433        assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
1434        assert_eq!(users[2].invite_count, 3);
1435
1436        let invite_code_1 = users[0].invite_code.clone().unwrap();
1437        let invite_code_2 = users[1].invite_code.clone().unwrap();
1438        let invite_code_3 = users[2].invite_code.clone().unwrap();
1439        assert_ne!(invite_code_1, invite_code_2);
1440        assert_ne!(invite_code_1, invite_code_3);
1441        assert_ne!(invite_code_2, invite_code_3);
1442
1443        // Create the second batch of users and include a user that is already in the database, ensuring
1444        // the invite count for the existing user is updated without changing their invite code.
1445        let user_ids_batch_2 = db
1446            .create_users(vec![
1447                ("user2".to_string(), "hi@user2.com".to_string(), 10),
1448                ("user4".to_string(), "hi@user4.com".to_string(), 2),
1449            ])
1450            .await
1451            .unwrap();
1452        assert_eq!(user_ids_batch_2.len(), 2);
1453        assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);
1454
1455        let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
1456        assert_eq!(users.len(), 2);
1457        assert_eq!(users[0].github_login, "user2");
1458        assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
1459        assert_eq!(users[0].invite_count, 10);
1460        assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
1461        assert_eq!(users[1].github_login, "user4");
1462        assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
1463        assert_eq!(users[1].invite_count, 2);
1464
1465        let invite_code_4 = users[1].invite_code.clone().unwrap();
1466        assert_ne!(invite_code_4, invite_code_1);
1467        assert_ne!(invite_code_4, invite_code_2);
1468        assert_ne!(invite_code_4, invite_code_3);
1469    }
1470
1471    #[tokio::test(flavor = "multi_thread")]
1472    async fn test_worktree_extensions() {
1473        let test_db = TestDb::postgres().await;
1474        let db = test_db.db();
1475
1476        let user = db.create_user("user_1", None, false).await.unwrap();
1477        let project = db.register_project(user).await.unwrap();
1478
1479        db.update_worktree_extensions(project, 100, Default::default())
1480            .await
1481            .unwrap();
1482        db.update_worktree_extensions(
1483            project,
1484            100,
1485            [("rs".to_string(), 5), ("md".to_string(), 3)]
1486                .into_iter()
1487                .collect(),
1488        )
1489        .await
1490        .unwrap();
1491        db.update_worktree_extensions(
1492            project,
1493            100,
1494            [("rs".to_string(), 6), ("md".to_string(), 5)]
1495                .into_iter()
1496                .collect(),
1497        )
1498        .await
1499        .unwrap();
1500        db.update_worktree_extensions(
1501            project,
1502            101,
1503            [("ts".to_string(), 2), ("md".to_string(), 1)]
1504                .into_iter()
1505                .collect(),
1506        )
1507        .await
1508        .unwrap();
1509
1510        assert_eq!(
1511            db.get_project_extensions(project).await.unwrap(),
1512            [
1513                (
1514                    100,
1515                    [("rs".into(), 6), ("md".into(), 5),]
1516                        .into_iter()
1517                        .collect::<HashMap<_, _>>()
1518                ),
1519                (
1520                    101,
1521                    [("ts".into(), 2), ("md".into(), 1),]
1522                        .into_iter()
1523                        .collect::<HashMap<_, _>>()
1524                )
1525            ]
1526            .into_iter()
1527            .collect()
1528        );
1529    }
1530
1531    #[tokio::test(flavor = "multi_thread")]
1532    async fn test_project_activity() {
1533        let test_db = TestDb::postgres().await;
1534        let db = test_db.db();
1535
1536        let user_1 = db.create_user("user_1", None, false).await.unwrap();
1537        let user_2 = db.create_user("user_2", None, false).await.unwrap();
1538        let user_3 = db.create_user("user_3", None, false).await.unwrap();
1539        let project_1 = db.register_project(user_1).await.unwrap();
1540        db.update_worktree_extensions(
1541            project_1,
1542            1,
1543            HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]),
1544        )
1545        .await
1546        .unwrap();
1547        let project_2 = db.register_project(user_2).await.unwrap();
1548        let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);
1549
1550        // User 2 opens a project
1551        let t1 = t0 + Duration::from_secs(10);
1552        db.record_project_activity(t0..t1, &[(user_2, project_2)])
1553            .await
1554            .unwrap();
1555
1556        let t2 = t1 + Duration::from_secs(10);
1557        db.record_project_activity(t1..t2, &[(user_2, project_2)])
1558            .await
1559            .unwrap();
1560
1561        // User 1 joins the project
1562        let t3 = t2 + Duration::from_secs(10);
1563        db.record_project_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
1564            .await
1565            .unwrap();
1566
1567        // User 1 opens another project
1568        let t4 = t3 + Duration::from_secs(10);
1569        db.record_project_activity(
1570            t3..t4,
1571            &[
1572                (user_2, project_2),
1573                (user_1, project_2),
1574                (user_1, project_1),
1575            ],
1576        )
1577        .await
1578        .unwrap();
1579
1580        // User 3 joins that project
1581        let t5 = t4 + Duration::from_secs(10);
1582        db.record_project_activity(
1583            t4..t5,
1584            &[
1585                (user_2, project_2),
1586                (user_1, project_2),
1587                (user_1, project_1),
1588                (user_3, project_1),
1589            ],
1590        )
1591        .await
1592        .unwrap();
1593
1594        // User 2 leaves
1595        let t6 = t5 + Duration::from_secs(5);
1596        db.record_project_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
1597            .await
1598            .unwrap();
1599
1600        let t7 = t6 + Duration::from_secs(60);
1601        let t8 = t7 + Duration::from_secs(10);
1602        db.record_project_activity(t7..t8, &[(user_1, project_1)])
1603            .await
1604            .unwrap();
1605
1606        assert_eq!(
1607            db.summarize_project_activity(t0..t6, 10).await.unwrap(),
1608            &[
1609                UserActivitySummary {
1610                    id: user_1,
1611                    github_login: "user_1".to_string(),
1612                    project_activity: vec![
1613                        (project_2, Duration::from_secs(30)),
1614                        (project_1, Duration::from_secs(25))
1615                    ]
1616                },
1617                UserActivitySummary {
1618                    id: user_2,
1619                    github_login: "user_2".to_string(),
1620                    project_activity: vec![(project_2, Duration::from_secs(50))]
1621                },
1622                UserActivitySummary {
1623                    id: user_3,
1624                    github_login: "user_3".to_string(),
1625                    project_activity: vec![(project_1, Duration::from_secs(15))]
1626                },
1627            ]
1628        );
1629        assert_eq!(
1630            db.summarize_user_activity(user_1, t3..t6).await.unwrap(),
1631            &[
1632                UserActivityDuration {
1633                    project_id: project_1,
1634                    start: t3,
1635                    end: t6,
1636                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1637                },
1638                UserActivityDuration {
1639                    project_id: project_2,
1640                    start: t3,
1641                    end: t5,
1642                    extensions: Default::default(),
1643                },
1644            ]
1645        );
1646        assert_eq!(
1647            db.summarize_user_activity(user_1, t0..t8).await.unwrap(),
1648            &[
1649                UserActivityDuration {
1650                    project_id: project_2,
1651                    start: t2,
1652                    end: t5,
1653                    extensions: Default::default(),
1654                },
1655                UserActivityDuration {
1656                    project_id: project_1,
1657                    start: t3,
1658                    end: t6,
1659                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1660                },
1661                UserActivityDuration {
1662                    project_id: project_1,
1663                    start: t7,
1664                    end: t8,
1665                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1666                },
1667            ]
1668        );
1669    }
1670
1671    #[tokio::test(flavor = "multi_thread")]
1672    async fn test_recent_channel_messages() {
1673        for test_db in [
1674            TestDb::postgres().await,
1675            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1676        ] {
1677            let db = test_db.db();
1678            let user = db.create_user("user", None, false).await.unwrap();
1679            let org = db.create_org("org", "org").await.unwrap();
1680            let channel = db.create_org_channel(org, "channel").await.unwrap();
1681            for i in 0..10 {
1682                db.create_channel_message(
1683                    channel,
1684                    user,
1685                    &i.to_string(),
1686                    OffsetDateTime::now_utc(),
1687                    i,
1688                )
1689                .await
1690                .unwrap();
1691            }
1692
1693            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1694            assert_eq!(
1695                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1696                ["5", "6", "7", "8", "9"]
1697            );
1698
1699            let prev_messages = db
1700                .get_channel_messages(channel, 4, Some(messages[0].id))
1701                .await
1702                .unwrap();
1703            assert_eq!(
1704                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1705                ["1", "2", "3", "4"]
1706            );
1707        }
1708    }
1709
1710    #[tokio::test(flavor = "multi_thread")]
1711    async fn test_channel_message_nonces() {
1712        for test_db in [
1713            TestDb::postgres().await,
1714            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1715        ] {
1716            let db = test_db.db();
1717            let user = db.create_user("user", None, false).await.unwrap();
1718            let org = db.create_org("org", "org").await.unwrap();
1719            let channel = db.create_org_channel(org, "channel").await.unwrap();
1720
1721            let msg1_id = db
1722                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1723                .await
1724                .unwrap();
1725            let msg2_id = db
1726                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1727                .await
1728                .unwrap();
1729            let msg3_id = db
1730                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1731                .await
1732                .unwrap();
1733            let msg4_id = db
1734                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1735                .await
1736                .unwrap();
1737
1738            assert_ne!(msg1_id, msg2_id);
1739            assert_eq!(msg1_id, msg3_id);
1740            assert_eq!(msg2_id, msg4_id);
1741        }
1742    }
1743
1744    #[tokio::test(flavor = "multi_thread")]
1745    async fn test_create_access_tokens() {
1746        let test_db = TestDb::postgres().await;
1747        let db = test_db.db();
1748        let user = db.create_user("the-user", None, false).await.unwrap();
1749
1750        db.create_access_token_hash(user, "h1", 3).await.unwrap();
1751        db.create_access_token_hash(user, "h2", 3).await.unwrap();
1752        assert_eq!(
1753            db.get_access_token_hashes(user).await.unwrap(),
1754            &["h2".to_string(), "h1".to_string()]
1755        );
1756
1757        db.create_access_token_hash(user, "h3", 3).await.unwrap();
1758        assert_eq!(
1759            db.get_access_token_hashes(user).await.unwrap(),
1760            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1761        );
1762
1763        db.create_access_token_hash(user, "h4", 3).await.unwrap();
1764        assert_eq!(
1765            db.get_access_token_hashes(user).await.unwrap(),
1766            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1767        );
1768
1769        db.create_access_token_hash(user, "h5", 3).await.unwrap();
1770        assert_eq!(
1771            db.get_access_token_hashes(user).await.unwrap(),
1772            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1773        );
1774    }
1775
1776    #[test]
1777    fn test_fuzzy_like_string() {
1778        assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1779        assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1780        assert_eq!(fuzzy_like_string(" z  "), "%z%");
1781    }
1782
1783    #[tokio::test(flavor = "multi_thread")]
1784    async fn test_fuzzy_search_users() {
1785        let test_db = TestDb::postgres().await;
1786        let db = test_db.db();
1787        for github_login in [
1788            "California",
1789            "colorado",
1790            "oregon",
1791            "washington",
1792            "florida",
1793            "delaware",
1794            "rhode-island",
1795        ] {
1796            db.create_user(github_login, None, false).await.unwrap();
1797        }
1798
1799        assert_eq!(
1800            fuzzy_search_user_names(db, "clr").await,
1801            &["colorado", "California"]
1802        );
1803        assert_eq!(
1804            fuzzy_search_user_names(db, "ro").await,
1805            &["rhode-island", "colorado", "oregon"],
1806        );
1807
1808        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1809            db.fuzzy_search_users(query, 10)
1810                .await
1811                .unwrap()
1812                .into_iter()
1813                .map(|user| user.github_login)
1814                .collect::<Vec<_>>()
1815        }
1816    }
1817
1818    #[tokio::test(flavor = "multi_thread")]
1819    async fn test_add_contacts() {
1820        for test_db in [
1821            TestDb::postgres().await,
1822            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1823        ] {
1824            let db = test_db.db();
1825
1826            let user_1 = db.create_user("user1", None, false).await.unwrap();
1827            let user_2 = db.create_user("user2", None, false).await.unwrap();
1828            let user_3 = db.create_user("user3", None, false).await.unwrap();
1829
1830            // User starts with no contacts
1831            assert_eq!(
1832                db.get_contacts(user_1).await.unwrap(),
1833                vec![Contact::Accepted {
1834                    user_id: user_1,
1835                    should_notify: false
1836                }],
1837            );
1838
1839            // User requests a contact. Both users see the pending request.
1840            db.send_contact_request(user_1, user_2).await.unwrap();
1841            assert!(!db.has_contact(user_1, user_2).await.unwrap());
1842            assert!(!db.has_contact(user_2, user_1).await.unwrap());
1843            assert_eq!(
1844                db.get_contacts(user_1).await.unwrap(),
1845                &[
1846                    Contact::Accepted {
1847                        user_id: user_1,
1848                        should_notify: false
1849                    },
1850                    Contact::Outgoing { user_id: user_2 }
1851                ],
1852            );
1853            assert_eq!(
1854                db.get_contacts(user_2).await.unwrap(),
1855                &[
1856                    Contact::Incoming {
1857                        user_id: user_1,
1858                        should_notify: true
1859                    },
1860                    Contact::Accepted {
1861                        user_id: user_2,
1862                        should_notify: false
1863                    },
1864                ]
1865            );
1866
1867            // User 2 dismisses the contact request notification without accepting or rejecting.
1868            // We shouldn't notify them again.
1869            db.dismiss_contact_notification(user_1, user_2)
1870                .await
1871                .unwrap_err();
1872            db.dismiss_contact_notification(user_2, user_1)
1873                .await
1874                .unwrap();
1875            assert_eq!(
1876                db.get_contacts(user_2).await.unwrap(),
1877                &[
1878                    Contact::Incoming {
1879                        user_id: user_1,
1880                        should_notify: false
1881                    },
1882                    Contact::Accepted {
1883                        user_id: user_2,
1884                        should_notify: false
1885                    },
1886                ]
1887            );
1888
1889            // User can't accept their own contact request
1890            db.respond_to_contact_request(user_1, user_2, true)
1891                .await
1892                .unwrap_err();
1893
1894            // User accepts a contact request. Both users see the contact.
1895            db.respond_to_contact_request(user_2, user_1, true)
1896                .await
1897                .unwrap();
1898            assert_eq!(
1899                db.get_contacts(user_1).await.unwrap(),
1900                &[
1901                    Contact::Accepted {
1902                        user_id: user_1,
1903                        should_notify: false
1904                    },
1905                    Contact::Accepted {
1906                        user_id: user_2,
1907                        should_notify: true
1908                    }
1909                ],
1910            );
1911            assert!(db.has_contact(user_1, user_2).await.unwrap());
1912            assert!(db.has_contact(user_2, user_1).await.unwrap());
1913            assert_eq!(
1914                db.get_contacts(user_2).await.unwrap(),
1915                &[
1916                    Contact::Accepted {
1917                        user_id: user_1,
1918                        should_notify: false,
1919                    },
1920                    Contact::Accepted {
1921                        user_id: user_2,
1922                        should_notify: false,
1923                    },
1924                ]
1925            );
1926
1927            // Users cannot re-request existing contacts.
1928            db.send_contact_request(user_1, user_2).await.unwrap_err();
1929            db.send_contact_request(user_2, user_1).await.unwrap_err();
1930
1931            // Users can't dismiss notifications of them accepting other users' requests.
1932            db.dismiss_contact_notification(user_2, user_1)
1933                .await
1934                .unwrap_err();
1935            assert_eq!(
1936                db.get_contacts(user_1).await.unwrap(),
1937                &[
1938                    Contact::Accepted {
1939                        user_id: user_1,
1940                        should_notify: false
1941                    },
1942                    Contact::Accepted {
1943                        user_id: user_2,
1944                        should_notify: true,
1945                    },
1946                ]
1947            );
1948
1949            // Users can dismiss notifications of other users accepting their requests.
1950            db.dismiss_contact_notification(user_1, user_2)
1951                .await
1952                .unwrap();
1953            assert_eq!(
1954                db.get_contacts(user_1).await.unwrap(),
1955                &[
1956                    Contact::Accepted {
1957                        user_id: user_1,
1958                        should_notify: false
1959                    },
1960                    Contact::Accepted {
1961                        user_id: user_2,
1962                        should_notify: false,
1963                    },
1964                ]
1965            );
1966
1967            // Users send each other concurrent contact requests and
1968            // see that they are immediately accepted.
1969            db.send_contact_request(user_1, user_3).await.unwrap();
1970            db.send_contact_request(user_3, user_1).await.unwrap();
1971            assert_eq!(
1972                db.get_contacts(user_1).await.unwrap(),
1973                &[
1974                    Contact::Accepted {
1975                        user_id: user_1,
1976                        should_notify: false
1977                    },
1978                    Contact::Accepted {
1979                        user_id: user_2,
1980                        should_notify: false,
1981                    },
1982                    Contact::Accepted {
1983                        user_id: user_3,
1984                        should_notify: false
1985                    },
1986                ]
1987            );
1988            assert_eq!(
1989                db.get_contacts(user_3).await.unwrap(),
1990                &[
1991                    Contact::Accepted {
1992                        user_id: user_1,
1993                        should_notify: false
1994                    },
1995                    Contact::Accepted {
1996                        user_id: user_3,
1997                        should_notify: false
1998                    }
1999                ],
2000            );
2001
2002            // User declines a contact request. Both users see that it is gone.
2003            db.send_contact_request(user_2, user_3).await.unwrap();
2004            db.respond_to_contact_request(user_3, user_2, false)
2005                .await
2006                .unwrap();
2007            assert!(!db.has_contact(user_2, user_3).await.unwrap());
2008            assert!(!db.has_contact(user_3, user_2).await.unwrap());
2009            assert_eq!(
2010                db.get_contacts(user_2).await.unwrap(),
2011                &[
2012                    Contact::Accepted {
2013                        user_id: user_1,
2014                        should_notify: false
2015                    },
2016                    Contact::Accepted {
2017                        user_id: user_2,
2018                        should_notify: false
2019                    }
2020                ]
2021            );
2022            assert_eq!(
2023                db.get_contacts(user_3).await.unwrap(),
2024                &[
2025                    Contact::Accepted {
2026                        user_id: user_1,
2027                        should_notify: false
2028                    },
2029                    Contact::Accepted {
2030                        user_id: user_3,
2031                        should_notify: false
2032                    }
2033                ],
2034            );
2035        }
2036    }
2037
2038    #[tokio::test(flavor = "multi_thread")]
2039    async fn test_invite_codes() {
2040        let postgres = TestDb::postgres().await;
2041        let db = postgres.db();
2042        let user1 = db.create_user("user-1", None, false).await.unwrap();
2043
2044        // Initially, user 1 has no invite code
2045        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);
2046
2047        // Setting invite count to 0 when no code is assigned does not assign a new code
2048        db.set_invite_count(user1, 0).await.unwrap();
2049        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());
2050
2051        // User 1 creates an invite code that can be used twice.
2052        db.set_invite_count(user1, 2).await.unwrap();
2053        let (invite_code, invite_count) =
2054            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2055        assert_eq!(invite_count, 2);
2056
2057        // User 2 redeems the invite code and becomes a contact of user 1.
2058        let user2 = db
2059            .redeem_invite_code(&invite_code, "user-2", None)
2060            .await
2061            .unwrap();
2062        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2063        assert_eq!(invite_count, 1);
2064        assert_eq!(
2065            db.get_contacts(user1).await.unwrap(),
2066            [
2067                Contact::Accepted {
2068                    user_id: user1,
2069                    should_notify: false
2070                },
2071                Contact::Accepted {
2072                    user_id: user2,
2073                    should_notify: true
2074                }
2075            ]
2076        );
2077        assert_eq!(
2078            db.get_contacts(user2).await.unwrap(),
2079            [
2080                Contact::Accepted {
2081                    user_id: user1,
2082                    should_notify: false
2083                },
2084                Contact::Accepted {
2085                    user_id: user2,
2086                    should_notify: false
2087                }
2088            ]
2089        );
2090
2091        // User 3 redeems the invite code and becomes a contact of user 1.
2092        let user3 = db
2093            .redeem_invite_code(&invite_code, "user-3", None)
2094            .await
2095            .unwrap();
2096        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2097        assert_eq!(invite_count, 0);
2098        assert_eq!(
2099            db.get_contacts(user1).await.unwrap(),
2100            [
2101                Contact::Accepted {
2102                    user_id: user1,
2103                    should_notify: false
2104                },
2105                Contact::Accepted {
2106                    user_id: user2,
2107                    should_notify: true
2108                },
2109                Contact::Accepted {
2110                    user_id: user3,
2111                    should_notify: true
2112                }
2113            ]
2114        );
2115        assert_eq!(
2116            db.get_contacts(user3).await.unwrap(),
2117            [
2118                Contact::Accepted {
2119                    user_id: user1,
2120                    should_notify: false
2121                },
2122                Contact::Accepted {
2123                    user_id: user3,
2124                    should_notify: false
2125                },
2126            ]
2127        );
2128
2129        // Trying to reedem the code for the third time results in an error.
2130        db.redeem_invite_code(&invite_code, "user-4", None)
2131            .await
2132            .unwrap_err();
2133
2134        // Invite count can be updated after the code has been created.
2135        db.set_invite_count(user1, 2).await.unwrap();
2136        let (latest_code, invite_count) =
2137            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2138        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
2139        assert_eq!(invite_count, 2);
2140
2141        // User 4 can now redeem the invite code and becomes a contact of user 1.
2142        let user4 = db
2143            .redeem_invite_code(&invite_code, "user-4", None)
2144            .await
2145            .unwrap();
2146        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2147        assert_eq!(invite_count, 1);
2148        assert_eq!(
2149            db.get_contacts(user1).await.unwrap(),
2150            [
2151                Contact::Accepted {
2152                    user_id: user1,
2153                    should_notify: false
2154                },
2155                Contact::Accepted {
2156                    user_id: user2,
2157                    should_notify: true
2158                },
2159                Contact::Accepted {
2160                    user_id: user3,
2161                    should_notify: true
2162                },
2163                Contact::Accepted {
2164                    user_id: user4,
2165                    should_notify: true
2166                }
2167            ]
2168        );
2169        assert_eq!(
2170            db.get_contacts(user4).await.unwrap(),
2171            [
2172                Contact::Accepted {
2173                    user_id: user1,
2174                    should_notify: false
2175                },
2176                Contact::Accepted {
2177                    user_id: user4,
2178                    should_notify: false
2179                },
2180            ]
2181        );
2182
2183        // An existing user cannot redeem invite codes.
2184        db.redeem_invite_code(&invite_code, "user-2", None)
2185            .await
2186            .unwrap_err();
2187        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2188        assert_eq!(invite_count, 1);
2189    }
2190
    /// A database handle for tests, backed by either a per-test Postgres database
    /// (`TestDb::postgres`) or an in-memory [`FakeDb`] (`TestDb::fake`).
    pub struct TestDb {
        /// `Some` for the handle's whole life; `Drop` takes it out to run `teardown`.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the Postgres test database; empty for the fake backend.
        pub url: String,
    }
2195
2196    impl TestDb {
2197        pub async fn postgres() -> Self {
2198            lazy_static! {
2199                static ref LOCK: Mutex<()> = Mutex::new(());
2200            }
2201
2202            let _guard = LOCK.lock();
2203            let mut rng = StdRng::from_entropy();
2204            let name = format!("zed-test-{}", rng.gen::<u128>());
2205            let url = format!("postgres://postgres@localhost/{}", name);
2206            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2207            Postgres::create_database(&url)
2208                .await
2209                .expect("failed to create test db");
2210            let db = PostgresDb::new(&url, 5).await.unwrap();
2211            let migrator = Migrator::new(migrations_path).await.unwrap();
2212            migrator.run(&db.pool).await.unwrap();
2213            Self {
2214                db: Some(Arc::new(db)),
2215                url,
2216            }
2217        }
2218
2219        pub fn fake(background: Arc<Background>) -> Self {
2220            Self {
2221                db: Some(Arc::new(FakeDb::new(background))),
2222                url: Default::default(),
2223            }
2224        }
2225
2226        pub fn db(&self) -> &Arc<dyn Db> {
2227            self.db.as_ref().unwrap()
2228        }
2229    }
2230
2231    impl Drop for TestDb {
2232        fn drop(&mut self) {
2233            if let Some(db) = self.db.take() {
2234                futures::executor::block_on(db.teardown(&self.url));
2235            }
2236        }
2237    }
2238
    /// In-memory implementation of the [`Db`] trait for tests.
    ///
    /// Each table is a `Mutex`-guarded collection, and new ids are drawn from the
    /// `next_*` counters (which `FakeDb::new` starts at 1).
    pub struct FakeDb {
        // Executor used by several methods to inject `simulate_random_delay` pauses.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        // Keyed by (project, worktree id, extension); the value is the stored count.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), usize>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        // Contact records in insertion order; see `FakeContact`.
        pub contacts: Mutex<Vec<FakeContact>>,
        // Id counters; e.g. user/project/org ids are allocated with `post_inc`.
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
2256
    /// A contact relationship between two users, as stored by [`FakeDb`].
    #[derive(Debug)]
    pub struct FakeContact {
        /// The user who sent the contact request.
        pub requester_id: UserId,
        /// The user the request was sent to.
        pub responder_id: UserId,
        /// `false` while the request is pending; `true` once it has been accepted.
        pub accepted: bool,
        /// While pending: whether the responder should be notified of the request.
        /// Once accepted: whether the requester should be notified of the acceptance.
        pub should_notify: bool,
    }
2264
2265    impl FakeDb {
2266        pub fn new(background: Arc<Background>) -> Self {
2267            Self {
2268                background,
2269                users: Default::default(),
2270                next_user_id: Mutex::new(1),
2271                projects: Default::default(),
2272                worktree_extensions: Default::default(),
2273                next_project_id: Mutex::new(1),
2274                orgs: Default::default(),
2275                next_org_id: Mutex::new(1),
2276                org_memberships: Default::default(),
2277                channels: Default::default(),
2278                next_channel_id: Mutex::new(1),
2279                channel_memberships: Default::default(),
2280                channel_messages: Default::default(),
2281                next_channel_message_id: Mutex::new(1),
2282                contacts: Default::default(),
2283            }
2284        }
2285    }
2286
2287    #[async_trait]
2288    impl Db for FakeDb {
2289        async fn create_user(
2290            &self,
2291            github_login: &str,
2292            email_address: Option<&str>,
2293            admin: bool,
2294        ) -> Result<UserId> {
2295            self.background.simulate_random_delay().await;
2296
2297            let mut users = self.users.lock();
2298            if let Some(user) = users
2299                .values()
2300                .find(|user| user.github_login == github_login)
2301            {
2302                Ok(user.id)
2303            } else {
2304                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2305                users.insert(
2306                    user_id,
2307                    User {
2308                        id: user_id,
2309                        github_login: github_login.to_string(),
2310                        email_address: email_address.map(str::to_string),
2311                        admin,
2312                        invite_code: None,
2313                        invite_count: 0,
2314                        connected_once: false,
2315                    },
2316                );
2317                Ok(user_id)
2318            }
2319        }
2320
        // Listing, bulk-creating and fuzzy-searching users are not supported by the fake;
        // each of these panics via `unimplemented!()` if called.
        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }

        async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
            unimplemented!()
        }

        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2332
2333        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2334            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2335        }
2336
2337        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2338            self.background.simulate_random_delay().await;
2339            let users = self.users.lock();
2340            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2341        }
2342
        // Not supported by the fake; panics via `unimplemented!()` if called.
        async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
            unimplemented!()
        }
2346
2347        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2348            Ok(self
2349                .users
2350                .lock()
2351                .values()
2352                .find(|user| user.github_login == github_login)
2353                .cloned())
2354        }
2355
        // Not supported by the fake; panics via `unimplemented!()` if called.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
2359
2360        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2361            self.background.simulate_random_delay().await;
2362            let mut users = self.users.lock();
2363            let mut user = users
2364                .get_mut(&id)
2365                .ok_or_else(|| anyhow!("user not found"))?;
2366            user.connected_once = connected_once;
2367            Ok(())
2368        }
2369
        // Not supported by the fake; panics via `unimplemented!()` if called.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }

        // invite codes: the fake never assigns codes, so lookups report none and every
        // other operation panics via `unimplemented!()`.

        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }

        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            Ok(None)
        }

        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }

        async fn redeem_invite_code(
            &self,
            _code: &str,
            _login: &str,
            _email_address: Option<&str>,
        ) -> Result<UserId> {
            unimplemented!()
        }
2396
2397        // projects
2398
2399        async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2400            self.background.simulate_random_delay().await;
2401            if !self.users.lock().contains_key(&host_user_id) {
2402                Err(anyhow!("no such user"))?;
2403            }
2404
2405            let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2406            self.projects.lock().insert(
2407                project_id,
2408                Project {
2409                    id: project_id,
2410                    host_user_id,
2411                    unregistered: false,
2412                },
2413            );
2414            Ok(project_id)
2415        }
2416
2417        async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2418            self.projects
2419                .lock()
2420                .get_mut(&project_id)
2421                .ok_or_else(|| anyhow!("no such project"))?
2422                .unregistered = true;
2423            Ok(())
2424        }
2425
2426        async fn update_worktree_extensions(
2427            &self,
2428            project_id: ProjectId,
2429            worktree_id: u64,
2430            extensions: HashMap<String, usize>,
2431        ) -> Result<()> {
2432            self.background.simulate_random_delay().await;
2433            if !self.projects.lock().contains_key(&project_id) {
2434                Err(anyhow!("no such project"))?;
2435            }
2436
2437            for (extension, count) in extensions {
2438                self.worktree_extensions
2439                    .lock()
2440                    .insert((project_id, worktree_id, extension), count);
2441            }
2442
2443            Ok(())
2444        }
2445
        // Project-extension reads and project/user activity tracking are not supported by
        // the fake; each of these panics via `unimplemented!()` if called.
        async fn get_project_extensions(
            &self,
            _project_id: ProjectId,
        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
            unimplemented!()
        }

        async fn record_project_activity(
            &self,
            _period: Range<OffsetDateTime>,
            _active_projects: &[(UserId, ProjectId)],
        ) -> Result<()> {
            unimplemented!()
        }

        async fn summarize_project_activity(
            &self,
            _period: Range<OffsetDateTime>,
            _limit: usize,
        ) -> Result<Vec<UserActivitySummary>> {
            unimplemented!()
        }

        async fn summarize_user_activity(
            &self,
            _user_id: UserId,
            _time_period: Range<OffsetDateTime>,
        ) -> Result<Vec<UserActivityDuration>> {
            unimplemented!()
        }
2476
2477        // contacts
2478
2479        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2480            self.background.simulate_random_delay().await;
2481            let mut contacts = vec![Contact::Accepted {
2482                user_id: id,
2483                should_notify: false,
2484            }];
2485
2486            for contact in self.contacts.lock().iter() {
2487                if contact.requester_id == id {
2488                    if contact.accepted {
2489                        contacts.push(Contact::Accepted {
2490                            user_id: contact.responder_id,
2491                            should_notify: contact.should_notify,
2492                        });
2493                    } else {
2494                        contacts.push(Contact::Outgoing {
2495                            user_id: contact.responder_id,
2496                        });
2497                    }
2498                } else if contact.responder_id == id {
2499                    if contact.accepted {
2500                        contacts.push(Contact::Accepted {
2501                            user_id: contact.requester_id,
2502                            should_notify: false,
2503                        });
2504                    } else {
2505                        contacts.push(Contact::Incoming {
2506                            user_id: contact.requester_id,
2507                            should_notify: contact.should_notify,
2508                        });
2509                    }
2510                }
2511            }
2512
2513            contacts.sort_unstable_by_key(|contact| contact.user_id());
2514            Ok(contacts)
2515        }
2516
2517        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2518            self.background.simulate_random_delay().await;
2519            Ok(self.contacts.lock().iter().any(|contact| {
2520                contact.accepted
2521                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2522                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2523            }))
2524        }
2525
2526        async fn send_contact_request(
2527            &self,
2528            requester_id: UserId,
2529            responder_id: UserId,
2530        ) -> Result<()> {
2531            let mut contacts = self.contacts.lock();
2532            for contact in contacts.iter_mut() {
2533                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2534                    if contact.accepted {
2535                        Err(anyhow!("contact already exists"))?;
2536                    } else {
2537                        Err(anyhow!("contact already requested"))?;
2538                    }
2539                }
2540                if contact.responder_id == requester_id && contact.requester_id == responder_id {
2541                    if contact.accepted {
2542                        Err(anyhow!("contact already exists"))?;
2543                    } else {
2544                        contact.accepted = true;
2545                        contact.should_notify = false;
2546                        return Ok(());
2547                    }
2548                }
2549            }
2550            contacts.push(FakeContact {
2551                requester_id,
2552                responder_id,
2553                accepted: false,
2554                should_notify: true,
2555            });
2556            Ok(())
2557        }
2558
2559        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2560            self.contacts.lock().retain(|contact| {
2561                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2562            });
2563            Ok(())
2564        }
2565
2566        async fn dismiss_contact_notification(
2567            &self,
2568            user_id: UserId,
2569            contact_user_id: UserId,
2570        ) -> Result<()> {
2571            let mut contacts = self.contacts.lock();
2572            for contact in contacts.iter_mut() {
2573                if contact.requester_id == contact_user_id
2574                    && contact.responder_id == user_id
2575                    && !contact.accepted
2576                {
2577                    contact.should_notify = false;
2578                    return Ok(());
2579                }
2580                if contact.requester_id == user_id
2581                    && contact.responder_id == contact_user_id
2582                    && contact.accepted
2583                {
2584                    contact.should_notify = false;
2585                    return Ok(());
2586                }
2587            }
2588            Err(anyhow!("no such notification"))?
2589        }
2590
2591        async fn respond_to_contact_request(
2592            &self,
2593            responder_id: UserId,
2594            requester_id: UserId,
2595            accept: bool,
2596        ) -> Result<()> {
2597            let mut contacts = self.contacts.lock();
2598            for (ix, contact) in contacts.iter_mut().enumerate() {
2599                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2600                    if contact.accepted {
2601                        Err(anyhow!("contact already confirmed"))?;
2602                    }
2603                    if accept {
2604                        contact.accepted = true;
2605                        contact.should_notify = true;
2606                    } else {
2607                        contacts.remove(ix);
2608                    }
2609                    return Ok(());
2610                }
2611            }
2612            Err(anyhow!("no such contact request"))?
2613        }
2614
        /// Access-token storage is not modeled by this fake; calling this panics
        /// via `unimplemented!()`.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }
2623
        /// Access-token storage is not modeled by this fake; calling this panics
        /// via `unimplemented!()`.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }
2627
2628        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
2629            unimplemented!()
2630        }
2631
2632        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2633            self.background.simulate_random_delay().await;
2634            let mut orgs = self.orgs.lock();
2635            if orgs.values().any(|org| org.slug == slug) {
2636                Err(anyhow!("org already exists"))?
2637            } else {
2638                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2639                orgs.insert(
2640                    org_id,
2641                    Org {
2642                        id: org_id,
2643                        name: name.to_string(),
2644                        slug: slug.to_string(),
2645                    },
2646                );
2647                Ok(org_id)
2648            }
2649        }
2650
2651        async fn add_org_member(
2652            &self,
2653            org_id: OrgId,
2654            user_id: UserId,
2655            is_admin: bool,
2656        ) -> Result<()> {
2657            self.background.simulate_random_delay().await;
2658            if !self.orgs.lock().contains_key(&org_id) {
2659                Err(anyhow!("org does not exist"))?;
2660            }
2661            if !self.users.lock().contains_key(&user_id) {
2662                Err(anyhow!("user does not exist"))?;
2663            }
2664
2665            self.org_memberships
2666                .lock()
2667                .entry((org_id, user_id))
2668                .or_insert(is_admin);
2669            Ok(())
2670        }
2671
2672        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2673            self.background.simulate_random_delay().await;
2674            if !self.orgs.lock().contains_key(&org_id) {
2675                Err(anyhow!("org does not exist"))?;
2676            }
2677
2678            let mut channels = self.channels.lock();
2679            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2680            channels.insert(
2681                channel_id,
2682                Channel {
2683                    id: channel_id,
2684                    name: name.to_string(),
2685                    owner_id: org_id.0,
2686                    owner_is_user: false,
2687                },
2688            );
2689            Ok(channel_id)
2690        }
2691
2692        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2693            self.background.simulate_random_delay().await;
2694            Ok(self
2695                .channels
2696                .lock()
2697                .values()
2698                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2699                .cloned()
2700                .collect())
2701        }
2702
2703        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2704            self.background.simulate_random_delay().await;
2705            let channels = self.channels.lock();
2706            let memberships = self.channel_memberships.lock();
2707            Ok(channels
2708                .values()
2709                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2710                .cloned()
2711                .collect())
2712        }
2713
2714        async fn can_user_access_channel(
2715            &self,
2716            user_id: UserId,
2717            channel_id: ChannelId,
2718        ) -> Result<bool> {
2719            self.background.simulate_random_delay().await;
2720            Ok(self
2721                .channel_memberships
2722                .lock()
2723                .contains_key(&(channel_id, user_id)))
2724        }
2725
2726        async fn add_channel_member(
2727            &self,
2728            channel_id: ChannelId,
2729            user_id: UserId,
2730            is_admin: bool,
2731        ) -> Result<()> {
2732            self.background.simulate_random_delay().await;
2733            if !self.channels.lock().contains_key(&channel_id) {
2734                Err(anyhow!("channel does not exist"))?;
2735            }
2736            if !self.users.lock().contains_key(&user_id) {
2737                Err(anyhow!("user does not exist"))?;
2738            }
2739
2740            self.channel_memberships
2741                .lock()
2742                .entry((channel_id, user_id))
2743                .or_insert(is_admin);
2744            Ok(())
2745        }
2746
2747        async fn create_channel_message(
2748            &self,
2749            channel_id: ChannelId,
2750            sender_id: UserId,
2751            body: &str,
2752            timestamp: OffsetDateTime,
2753            nonce: u128,
2754        ) -> Result<MessageId> {
2755            self.background.simulate_random_delay().await;
2756            if !self.channels.lock().contains_key(&channel_id) {
2757                Err(anyhow!("channel does not exist"))?;
2758            }
2759            if !self.users.lock().contains_key(&sender_id) {
2760                Err(anyhow!("user does not exist"))?;
2761            }
2762
2763            let mut messages = self.channel_messages.lock();
2764            if let Some(message) = messages
2765                .values()
2766                .find(|message| message.nonce.as_u128() == nonce)
2767            {
2768                Ok(message.id)
2769            } else {
2770                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2771                messages.insert(
2772                    message_id,
2773                    ChannelMessage {
2774                        id: message_id,
2775                        channel_id,
2776                        sender_id,
2777                        body: body.to_string(),
2778                        sent_at: timestamp,
2779                        nonce: Uuid::from_u128(nonce),
2780                    },
2781                );
2782                Ok(message_id)
2783            }
2784        }
2785
2786        async fn get_channel_messages(
2787            &self,
2788            channel_id: ChannelId,
2789            count: usize,
2790            before_id: Option<MessageId>,
2791        ) -> Result<Vec<ChannelMessage>> {
2792            let mut messages = self
2793                .channel_messages
2794                .lock()
2795                .values()
2796                .rev()
2797                .filter(|message| {
2798                    message.channel_id == channel_id
2799                        && message.id < before_id.unwrap_or(MessageId::MAX)
2800                })
2801                .take(count)
2802                .cloned()
2803                .collect::<Vec<_>>();
2804            messages.sort_unstable_by_key(|message| message.id);
2805            Ok(messages)
2806        }
2807
        /// No-op for the fake: its state lives in in-memory maps, so there is
        /// nothing to tear down and the string argument is ignored.
        async fn teardown(&self, _: &str) {}
2809
        /// Returns this database as a `FakeDb`; for the fake this is always
        /// `Some(self)`. Only compiled in test builds.
        #[cfg(test)]
        fn as_fake(&self) -> Option<&FakeDb> {
            Some(self)
        }
2814    }
2815}