db.rs

   1use std::{ops::Range, time::Duration};
   2
   3use crate::{Error, Result};
   4use anyhow::{anyhow, Context};
   5use async_trait::async_trait;
   6use axum::http::StatusCode;
   7use collections::HashMap;
   8use futures::StreamExt;
   9use nanoid::nanoid;
  10use serde::{Deserialize, Serialize};
  11pub use sqlx::postgres::PgPoolOptions as DbOptions;
  12use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
  13use time::{OffsetDateTime, PrimitiveDateTime};
  14
  15#[async_trait]
  16pub trait Db: Send + Sync {
  17    async fn create_user(
  18        &self,
  19        github_login: &str,
  20        email_address: Option<&str>,
  21        admin: bool,
  22    ) -> Result<UserId>;
  23    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
  24    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
  25    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
  26    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
  27    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
  28    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
  29    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
  30    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
  31    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
  32    async fn destroy_user(&self, id: UserId) -> Result<()>;
  33
  34    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
  35    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
  36    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
  37    async fn redeem_invite_code(
  38        &self,
  39        code: &str,
  40        login: &str,
  41        email_address: Option<&str>,
  42    ) -> Result<UserId>;
  43
  44    /// Registers a new project for the given user.
  45    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
  46
  47    /// Unregisters a project for the given project id.
  48    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
  49
  50    /// Update file counts by extension for the given project and worktree.
  51    async fn update_worktree_extensions(
  52        &self,
  53        project_id: ProjectId,
  54        worktree_id: u64,
  55        extensions: HashMap<String, usize>,
  56    ) -> Result<()>;
  57
  58    /// Get the file counts on the given project keyed by their worktree and extension.
  59    async fn get_project_extensions(
  60        &self,
  61        project_id: ProjectId,
  62    ) -> Result<HashMap<u64, HashMap<String, usize>>>;
  63
  64    /// Record which users have been active in which projects during
  65    /// a given period of time.
  66    async fn record_user_activity(
  67        &self,
  68        time_period: Range<OffsetDateTime>,
  69        active_projects: &[(UserId, ProjectId)],
  70    ) -> Result<()>;
  71
  72    /// Get the users that have been most active during the given time period,
  73    /// along with the amount of time they have been active in each project.
  74    async fn get_top_users_activity_summary(
  75        &self,
  76        time_period: Range<OffsetDateTime>,
  77        max_user_count: usize,
  78    ) -> Result<Vec<UserActivitySummary>>;
  79
  80    /// Get the project activity for the given user and time period.
  81    async fn get_user_activity_timeline(
  82        &self,
  83        time_period: Range<OffsetDateTime>,
  84        user_id: UserId,
  85    ) -> Result<Vec<UserActivityPeriod>>;
  86
  87    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
  88    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
  89    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  90    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  91    async fn dismiss_contact_notification(
  92        &self,
  93        responder_id: UserId,
  94        requester_id: UserId,
  95    ) -> Result<()>;
  96    async fn respond_to_contact_request(
  97        &self,
  98        responder_id: UserId,
  99        requester_id: UserId,
 100        accept: bool,
 101    ) -> Result<()>;
 102
 103    async fn create_access_token_hash(
 104        &self,
 105        user_id: UserId,
 106        access_token_hash: &str,
 107        max_access_token_count: usize,
 108    ) -> Result<()>;
 109    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
 110    #[cfg(any(test, feature = "seed-support"))]
 111
 112    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
 113    #[cfg(any(test, feature = "seed-support"))]
 114    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
 115    #[cfg(any(test, feature = "seed-support"))]
 116    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
 117    #[cfg(any(test, feature = "seed-support"))]
 118    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
 119    #[cfg(any(test, feature = "seed-support"))]
 120
 121    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
 122    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
 123    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
 124        -> Result<bool>;
 125    #[cfg(any(test, feature = "seed-support"))]
 126    async fn add_channel_member(
 127        &self,
 128        channel_id: ChannelId,
 129        user_id: UserId,
 130        is_admin: bool,
 131    ) -> Result<()>;
 132    async fn create_channel_message(
 133        &self,
 134        channel_id: ChannelId,
 135        sender_id: UserId,
 136        body: &str,
 137        timestamp: OffsetDateTime,
 138        nonce: u128,
 139    ) -> Result<MessageId>;
 140    async fn get_channel_messages(
 141        &self,
 142        channel_id: ChannelId,
 143        count: usize,
 144        before_id: Option<MessageId>,
 145    ) -> Result<Vec<ChannelMessage>>;
 146    #[cfg(test)]
 147    async fn teardown(&self, url: &str);
 148    #[cfg(test)]
 149    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
 150}
 151
 152pub struct PostgresDb {
 153    pool: sqlx::PgPool,
 154}
 155
 156impl PostgresDb {
 157    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 158        let pool = DbOptions::new()
 159            .max_connections(max_connections)
 160            .connect(&url)
 161            .await
 162            .context("failed to connect to postgres database")?;
 163        Ok(Self { pool })
 164    }
 165}
 166
 167#[async_trait]
 168impl Db for PostgresDb {
 169    // users
 170
 171    async fn create_user(
 172        &self,
 173        github_login: &str,
 174        email_address: Option<&str>,
 175        admin: bool,
 176    ) -> Result<UserId> {
 177        let query = "
 178            INSERT INTO users (github_login, email_address, admin)
 179            VALUES ($1, $2, $3)
 180            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 181            RETURNING id
 182        ";
 183        Ok(sqlx::query_scalar(query)
 184            .bind(github_login)
 185            .bind(email_address)
 186            .bind(admin)
 187            .fetch_one(&self.pool)
 188            .await
 189            .map(UserId)?)
 190    }
 191
 192    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 193        let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 194        Ok(sqlx::query_as(query)
 195            .bind(limit as i32)
 196            .bind((page * limit) as i32)
 197            .fetch_all(&self.pool)
 198            .await?)
 199    }
 200
 201    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
 202        let mut query = QueryBuilder::new(
 203            "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
 204        );
 205        query.push_values(
 206            users,
 207            |mut query, (github_login, email_address, invite_count)| {
 208                query
 209                    .push_bind(github_login)
 210                    .push_bind(email_address)
 211                    .push_bind(false)
 212                    .push_bind(nanoid!(16))
 213                    .push_bind(invite_count as i32);
 214            },
 215        );
 216        query.push(
 217            "
 218            ON CONFLICT (github_login) DO UPDATE SET
 219                github_login = excluded.github_login,
 220                invite_count = excluded.invite_count,
 221                invite_code = CASE WHEN users.invite_code IS NULL
 222                                   THEN excluded.invite_code
 223                                   ELSE users.invite_code
 224                              END
 225            RETURNING id
 226            ",
 227        );
 228
 229        let rows = query.build().fetch_all(&self.pool).await?;
 230        Ok(rows
 231            .into_iter()
 232            .filter_map(|row| row.try_get::<UserId, _>(0).ok())
 233            .collect())
 234    }
 235
 236    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 237        let like_string = fuzzy_like_string(name_query);
 238        let query = "
 239            SELECT users.*
 240            FROM users
 241            WHERE github_login ILIKE $1
 242            ORDER BY github_login <-> $2
 243            LIMIT $3
 244        ";
 245        Ok(sqlx::query_as(query)
 246            .bind(like_string)
 247            .bind(name_query)
 248            .bind(limit as i32)
 249            .fetch_all(&self.pool)
 250            .await?)
 251    }
 252
 253    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 254        let users = self.get_users_by_ids(vec![id]).await?;
 255        Ok(users.into_iter().next())
 256    }
 257
 258    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 259        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 260        let query = "
 261            SELECT users.*
 262            FROM users
 263            WHERE users.id = ANY ($1)
 264        ";
 265        Ok(sqlx::query_as(query)
 266            .bind(&ids)
 267            .fetch_all(&self.pool)
 268            .await?)
 269    }
 270
 271    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
 272        let query = format!(
 273            "
 274            SELECT users.*
 275            FROM users
 276            WHERE invite_count = 0
 277            AND inviter_id IS{} NULL
 278            ",
 279            if invited_by_another_user { " NOT" } else { "" }
 280        );
 281
 282        Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 283    }
 284
 285    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 286        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 287        Ok(sqlx::query_as(query)
 288            .bind(github_login)
 289            .fetch_optional(&self.pool)
 290            .await?)
 291    }
 292
 293    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 294        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 295        Ok(sqlx::query(query)
 296            .bind(is_admin)
 297            .bind(id.0)
 298            .execute(&self.pool)
 299            .await
 300            .map(drop)?)
 301    }
 302
 303    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 304        let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 305        Ok(sqlx::query(query)
 306            .bind(connected_once)
 307            .bind(id.0)
 308            .execute(&self.pool)
 309            .await
 310            .map(drop)?)
 311    }
 312
 313    async fn destroy_user(&self, id: UserId) -> Result<()> {
 314        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 315        sqlx::query(query)
 316            .bind(id.0)
 317            .execute(&self.pool)
 318            .await
 319            .map(drop)?;
 320        let query = "DELETE FROM users WHERE id = $1;";
 321        Ok(sqlx::query(query)
 322            .bind(id.0)
 323            .execute(&self.pool)
 324            .await
 325            .map(drop)?)
 326    }
 327
 328    // invite codes
 329
 330    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
 331        let mut tx = self.pool.begin().await?;
 332        if count > 0 {
 333            sqlx::query(
 334                "
 335                UPDATE users
 336                SET invite_code = $1
 337                WHERE id = $2 AND invite_code IS NULL
 338            ",
 339            )
 340            .bind(nanoid!(16))
 341            .bind(id)
 342            .execute(&mut tx)
 343            .await?;
 344        }
 345
 346        sqlx::query(
 347            "
 348            UPDATE users
 349            SET invite_count = $1
 350            WHERE id = $2
 351            ",
 352        )
 353        .bind(count as i32)
 354        .bind(id)
 355        .execute(&mut tx)
 356        .await?;
 357        tx.commit().await?;
 358        Ok(())
 359    }
 360
 361    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 362        let result: Option<(String, i32)> = sqlx::query_as(
 363            "
 364                SELECT invite_code, invite_count
 365                FROM users
 366                WHERE id = $1 AND invite_code IS NOT NULL 
 367            ",
 368        )
 369        .bind(id)
 370        .fetch_optional(&self.pool)
 371        .await?;
 372        if let Some((code, count)) = result {
 373            Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 374        } else {
 375            Ok(None)
 376        }
 377    }
 378
 379    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 380        sqlx::query_as(
 381            "
 382                SELECT *
 383                FROM users
 384                WHERE invite_code = $1
 385            ",
 386        )
 387        .bind(code)
 388        .fetch_optional(&self.pool)
 389        .await?
 390        .ok_or_else(|| {
 391            Error::Http(
 392                StatusCode::NOT_FOUND,
 393                "that invite code does not exist".to_string(),
 394            )
 395        })
 396    }
 397
 398    async fn redeem_invite_code(
 399        &self,
 400        code: &str,
 401        login: &str,
 402        email_address: Option<&str>,
 403    ) -> Result<UserId> {
 404        let mut tx = self.pool.begin().await?;
 405
 406        let inviter_id: Option<UserId> = sqlx::query_scalar(
 407            "
 408                UPDATE users
 409                SET invite_count = invite_count - 1
 410                WHERE
 411                    invite_code = $1 AND
 412                    invite_count > 0
 413                RETURNING id
 414            ",
 415        )
 416        .bind(code)
 417        .fetch_optional(&mut tx)
 418        .await?;
 419
 420        let inviter_id = match inviter_id {
 421            Some(inviter_id) => inviter_id,
 422            None => {
 423                if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
 424                    .bind(code)
 425                    .fetch_optional(&mut tx)
 426                    .await?
 427                    .is_some()
 428                {
 429                    Err(Error::Http(
 430                        StatusCode::UNAUTHORIZED,
 431                        "no invites remaining".to_string(),
 432                    ))?
 433                } else {
 434                    Err(Error::Http(
 435                        StatusCode::NOT_FOUND,
 436                        "invite code not found".to_string(),
 437                    ))?
 438                }
 439            }
 440        };
 441
 442        let invitee_id = sqlx::query_scalar(
 443            "
 444                INSERT INTO users
 445                    (github_login, email_address, admin, inviter_id)
 446                VALUES
 447                    ($1, $2, 'f', $3)
 448                RETURNING id
 449            ",
 450        )
 451        .bind(login)
 452        .bind(email_address)
 453        .bind(inviter_id)
 454        .fetch_one(&mut tx)
 455        .await
 456        .map(UserId)?;
 457
 458        sqlx::query(
 459            "
 460                INSERT INTO contacts
 461                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 462                VALUES
 463                    ($1, $2, 't', 't', 't')
 464            ",
 465        )
 466        .bind(inviter_id)
 467        .bind(invitee_id)
 468        .execute(&mut tx)
 469        .await?;
 470
 471        tx.commit().await?;
 472        Ok(invitee_id)
 473    }
 474
 475    // projects
 476
 477    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
 478        Ok(sqlx::query_scalar(
 479            "
 480            INSERT INTO projects(host_user_id)
 481            VALUES ($1)
 482            RETURNING id
 483            ",
 484        )
 485        .bind(host_user_id)
 486        .fetch_one(&self.pool)
 487        .await
 488        .map(ProjectId)?)
 489    }
 490
 491    async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
 492        sqlx::query(
 493            "
 494            UPDATE projects
 495            SET unregistered = 't'
 496            WHERE id = $1
 497            ",
 498        )
 499        .bind(project_id)
 500        .execute(&self.pool)
 501        .await?;
 502        Ok(())
 503    }
 504
 505    async fn update_worktree_extensions(
 506        &self,
 507        project_id: ProjectId,
 508        worktree_id: u64,
 509        extensions: HashMap<String, usize>,
 510    ) -> Result<()> {
 511        if extensions.is_empty() {
 512            return Ok(());
 513        }
 514
 515        let mut query = QueryBuilder::new(
 516            "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
 517        );
 518        query.push_values(extensions, |mut query, (extension, count)| {
 519            query
 520                .push_bind(project_id)
 521                .push_bind(worktree_id as i32)
 522                .push_bind(extension)
 523                .push_bind(count as i32);
 524        });
 525        query.push(
 526            "
 527            ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
 528            count = excluded.count
 529            ",
 530        );
 531        query.build().execute(&self.pool).await?;
 532
 533        Ok(())
 534    }
 535
 536    async fn get_project_extensions(
 537        &self,
 538        project_id: ProjectId,
 539    ) -> Result<HashMap<u64, HashMap<String, usize>>> {
 540        #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
 541        struct WorktreeExtension {
 542            worktree_id: i32,
 543            extension: String,
 544            count: i32,
 545        }
 546
 547        let query = "
 548            SELECT worktree_id, extension, count
 549            FROM worktree_extensions
 550            WHERE project_id = $1
 551        ";
 552        let counts = sqlx::query_as::<_, WorktreeExtension>(query)
 553            .bind(&project_id)
 554            .fetch_all(&self.pool)
 555            .await?;
 556
 557        let mut extension_counts = HashMap::default();
 558        for count in counts {
 559            extension_counts
 560                .entry(count.worktree_id as u64)
 561                .or_insert(HashMap::default())
 562                .insert(count.extension, count.count as usize);
 563        }
 564        Ok(extension_counts)
 565    }
 566
 567    async fn record_user_activity(
 568        &self,
 569        time_period: Range<OffsetDateTime>,
 570        projects: &[(UserId, ProjectId)],
 571    ) -> Result<()> {
 572        let query = "
 573            INSERT INTO project_activity_periods
 574            (ended_at, duration_millis, user_id, project_id)
 575            VALUES
 576            ($1, $2, $3, $4);
 577        ";
 578
 579        let mut tx = self.pool.begin().await?;
 580        let duration_millis =
 581            ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
 582        for (user_id, project_id) in projects {
 583            sqlx::query(query)
 584                .bind(time_period.end)
 585                .bind(duration_millis)
 586                .bind(user_id)
 587                .bind(project_id)
 588                .execute(&mut tx)
 589                .await?;
 590        }
 591        tx.commit().await?;
 592
 593        Ok(())
 594    }
 595
 596    async fn get_top_users_activity_summary(
 597        &self,
 598        time_period: Range<OffsetDateTime>,
 599        max_user_count: usize,
 600    ) -> Result<Vec<UserActivitySummary>> {
 601        let query = "
 602            WITH
 603                project_durations AS (
 604                    SELECT user_id, project_id, SUM(duration_millis) AS project_duration
 605                    FROM project_activity_periods
 606                    WHERE $1 < ended_at AND ended_at <= $2
 607                    GROUP BY user_id, project_id
 608                ),
 609                user_durations AS (
 610                    SELECT user_id, SUM(project_duration) as total_duration
 611                    FROM project_durations
 612                    GROUP BY user_id
 613                    ORDER BY total_duration DESC
 614                    LIMIT $3
 615                )
 616            SELECT user_durations.user_id, users.github_login, project_id, project_duration
 617            FROM user_durations, project_durations, users
 618            WHERE
 619                user_durations.user_id = project_durations.user_id AND
 620                user_durations.user_id = users.id
 621            ORDER BY user_id ASC, project_duration DESC
 622        ";
 623
 624        let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64)>(query)
 625            .bind(time_period.start)
 626            .bind(time_period.end)
 627            .bind(max_user_count as i32)
 628            .fetch(&self.pool);
 629
 630        let mut result = Vec::<UserActivitySummary>::new();
 631        while let Some(row) = rows.next().await {
 632            let (user_id, github_login, project_id, duration_millis) = row?;
 633            let project_id = project_id;
 634            let duration = Duration::from_millis(duration_millis as u64);
 635            if let Some(last_summary) = result.last_mut() {
 636                if last_summary.id == user_id {
 637                    last_summary.project_activity.push((project_id, duration));
 638                    continue;
 639                }
 640            }
 641            result.push(UserActivitySummary {
 642                id: user_id,
 643                project_activity: vec![(project_id, duration)],
 644                github_login,
 645            });
 646        }
 647
 648        Ok(result)
 649    }
 650
 651    async fn get_user_activity_timeline(
 652        &self,
 653        time_period: Range<OffsetDateTime>,
 654        user_id: UserId,
 655    ) -> Result<Vec<UserActivityPeriod>> {
 656        const COALESCE_THRESHOLD: Duration = Duration::from_secs(5);
 657
 658        let query = "
 659            SELECT
 660                project_activity_periods.ended_at,
 661                project_activity_periods.duration_millis,
 662                project_activity_periods.project_id,
 663                worktree_extensions.extension,
 664                worktree_extensions.count
 665            FROM project_activity_periods
 666            LEFT OUTER JOIN
 667                worktree_extensions
 668            ON
 669                project_activity_periods.project_id = worktree_extensions.project_id
 670            WHERE
 671                project_activity_periods.user_id = $1 AND
 672                $2 < project_activity_periods.ended_at AND
 673                project_activity_periods.ended_at <= $3
 674            ORDER BY project_activity_periods.id ASC
 675        ";
 676
 677        let mut rows = sqlx::query_as::<
 678            _,
 679            (
 680                PrimitiveDateTime,
 681                i32,
 682                ProjectId,
 683                Option<String>,
 684                Option<i32>,
 685            ),
 686        >(query)
 687        .bind(user_id)
 688        .bind(time_period.start)
 689        .bind(time_period.end)
 690        .fetch(&self.pool);
 691
 692        let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
 693        while let Some(row) = rows.next().await {
 694            let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
 695            let ended_at = ended_at.assume_utc();
 696            let duration = Duration::from_millis(duration_millis as u64);
 697            let started_at = ended_at - duration;
 698            let project_time_periods = time_periods.entry(project_id).or_default();
 699
 700            if let Some(prev_duration) = project_time_periods.last_mut() {
 701                if started_at - prev_duration.end <= COALESCE_THRESHOLD {
 702                    prev_duration.end = ended_at;
 703                } else {
 704                    project_time_periods.push(UserActivityPeriod {
 705                        project_id,
 706                        start: started_at,
 707                        end: ended_at,
 708                        extensions: Default::default(),
 709                    });
 710                }
 711            } else {
 712                project_time_periods.push(UserActivityPeriod {
 713                    project_id,
 714                    start: started_at,
 715                    end: ended_at,
 716                    extensions: Default::default(),
 717                });
 718            }
 719
 720            if let Some((extension, extension_count)) = extension.zip(extension_count) {
 721                project_time_periods
 722                    .last_mut()
 723                    .unwrap()
 724                    .extensions
 725                    .insert(extension, extension_count as usize);
 726            }
 727        }
 728
 729        let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
 730        durations.sort_unstable_by_key(|duration| duration.start);
 731        Ok(durations)
 732    }
 733
 734    // contacts
 735
 736    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 737        let query = "
 738            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 739            FROM contacts
 740            WHERE user_id_a = $1 OR user_id_b = $1;
 741        ";
 742
 743        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 744            .bind(user_id)
 745            .fetch(&self.pool);
 746
 747        let mut contacts = vec![Contact::Accepted {
 748            user_id,
 749            should_notify: false,
 750        }];
 751        while let Some(row) = rows.next().await {
 752            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 753
 754            if user_id_a == user_id {
 755                if accepted {
 756                    contacts.push(Contact::Accepted {
 757                        user_id: user_id_b,
 758                        should_notify: should_notify && a_to_b,
 759                    });
 760                } else if a_to_b {
 761                    contacts.push(Contact::Outgoing { user_id: user_id_b })
 762                } else {
 763                    contacts.push(Contact::Incoming {
 764                        user_id: user_id_b,
 765                        should_notify,
 766                    });
 767                }
 768            } else {
 769                if accepted {
 770                    contacts.push(Contact::Accepted {
 771                        user_id: user_id_a,
 772                        should_notify: should_notify && !a_to_b,
 773                    });
 774                } else if a_to_b {
 775                    contacts.push(Contact::Incoming {
 776                        user_id: user_id_a,
 777                        should_notify,
 778                    });
 779                } else {
 780                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 781                }
 782            }
 783        }
 784
 785        contacts.sort_unstable_by_key(|contact| contact.user_id());
 786
 787        Ok(contacts)
 788    }
 789
 790    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 791        let (id_a, id_b) = if user_id_1 < user_id_2 {
 792            (user_id_1, user_id_2)
 793        } else {
 794            (user_id_2, user_id_1)
 795        };
 796
 797        let query = "
 798            SELECT 1 FROM contacts
 799            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
 800            LIMIT 1
 801        ";
 802        Ok(sqlx::query_scalar::<_, i32>(query)
 803            .bind(id_a.0)
 804            .bind(id_b.0)
 805            .fetch_optional(&self.pool)
 806            .await?
 807            .is_some())
 808    }
 809
 810    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 811        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 812            (sender_id, receiver_id, true)
 813        } else {
 814            (receiver_id, sender_id, false)
 815        };
 816        let query = "
 817            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 818            VALUES ($1, $2, $3, 'f', 't')
 819            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 820            SET
 821                accepted = 't',
 822                should_notify = 'f'
 823            WHERE
 824                NOT contacts.accepted AND
 825                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 826                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 827        ";
 828        let result = sqlx::query(query)
 829            .bind(id_a.0)
 830            .bind(id_b.0)
 831            .bind(a_to_b)
 832            .execute(&self.pool)
 833            .await?;
 834
 835        if result.rows_affected() == 1 {
 836            Ok(())
 837        } else {
 838            Err(anyhow!("contact already requested"))?
 839        }
 840    }
 841
 842    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 843        let (id_a, id_b) = if responder_id < requester_id {
 844            (responder_id, requester_id)
 845        } else {
 846            (requester_id, responder_id)
 847        };
 848        let query = "
 849            DELETE FROM contacts
 850            WHERE user_id_a = $1 AND user_id_b = $2;
 851        ";
 852        let result = sqlx::query(query)
 853            .bind(id_a.0)
 854            .bind(id_b.0)
 855            .execute(&self.pool)
 856            .await?;
 857
 858        if result.rows_affected() == 1 {
 859            Ok(())
 860        } else {
 861            Err(anyhow!("no such contact"))?
 862        }
 863    }
 864
 865    async fn dismiss_contact_notification(
 866        &self,
 867        user_id: UserId,
 868        contact_user_id: UserId,
 869    ) -> Result<()> {
 870        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
 871            (user_id, contact_user_id, true)
 872        } else {
 873            (contact_user_id, user_id, false)
 874        };
 875
 876        let query = "
 877            UPDATE contacts
 878            SET should_notify = 'f'
 879            WHERE
 880                user_id_a = $1 AND user_id_b = $2 AND
 881                (
 882                    (a_to_b = $3 AND accepted) OR
 883                    (a_to_b != $3 AND NOT accepted)
 884                );
 885        ";
 886
 887        let result = sqlx::query(query)
 888            .bind(id_a.0)
 889            .bind(id_b.0)
 890            .bind(a_to_b)
 891            .execute(&self.pool)
 892            .await?;
 893
 894        if result.rows_affected() == 0 {
 895            Err(anyhow!("no such contact request"))?;
 896        }
 897
 898        Ok(())
 899    }
 900
 901    async fn respond_to_contact_request(
 902        &self,
 903        responder_id: UserId,
 904        requester_id: UserId,
 905        accept: bool,
 906    ) -> Result<()> {
 907        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
 908            (responder_id, requester_id, false)
 909        } else {
 910            (requester_id, responder_id, true)
 911        };
 912        let result = if accept {
 913            let query = "
 914                UPDATE contacts
 915                SET accepted = 't', should_notify = 't'
 916                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
 917            ";
 918            sqlx::query(query)
 919                .bind(id_a.0)
 920                .bind(id_b.0)
 921                .bind(a_to_b)
 922                .execute(&self.pool)
 923                .await?
 924        } else {
 925            let query = "
 926                DELETE FROM contacts
 927                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
 928            ";
 929            sqlx::query(query)
 930                .bind(id_a.0)
 931                .bind(id_b.0)
 932                .bind(a_to_b)
 933                .execute(&self.pool)
 934                .await?
 935        };
 936        if result.rows_affected() == 1 {
 937            Ok(())
 938        } else {
 939            Err(anyhow!("no such contact request"))?
 940        }
 941    }
 942
 943    // access tokens
 944
 945    async fn create_access_token_hash(
 946        &self,
 947        user_id: UserId,
 948        access_token_hash: &str,
 949        max_access_token_count: usize,
 950    ) -> Result<()> {
 951        let insert_query = "
 952            INSERT INTO access_tokens (user_id, hash)
 953            VALUES ($1, $2);
 954        ";
 955        let cleanup_query = "
 956            DELETE FROM access_tokens
 957            WHERE id IN (
 958                SELECT id from access_tokens
 959                WHERE user_id = $1
 960                ORDER BY id DESC
 961                OFFSET $3
 962            )
 963        ";
 964
 965        let mut tx = self.pool.begin().await?;
 966        sqlx::query(insert_query)
 967            .bind(user_id.0)
 968            .bind(access_token_hash)
 969            .execute(&mut tx)
 970            .await?;
 971        sqlx::query(cleanup_query)
 972            .bind(user_id.0)
 973            .bind(access_token_hash)
 974            .bind(max_access_token_count as i32)
 975            .execute(&mut tx)
 976            .await?;
 977        Ok(tx.commit().await?)
 978    }
 979
 980    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 981        let query = "
 982            SELECT hash
 983            FROM access_tokens
 984            WHERE user_id = $1
 985            ORDER BY id DESC
 986        ";
 987        Ok(sqlx::query_scalar(query)
 988            .bind(user_id.0)
 989            .fetch_all(&self.pool)
 990            .await?)
 991    }
 992
 993    // orgs
 994
 995    #[allow(unused)] // Help rust-analyzer
 996    #[cfg(any(test, feature = "seed-support"))]
 997    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
 998        let query = "
 999            SELECT *
1000            FROM orgs
1001            WHERE slug = $1
1002        ";
1003        Ok(sqlx::query_as(query)
1004            .bind(slug)
1005            .fetch_optional(&self.pool)
1006            .await?)
1007    }
1008
1009    #[cfg(any(test, feature = "seed-support"))]
1010    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1011        let query = "
1012            INSERT INTO orgs (name, slug)
1013            VALUES ($1, $2)
1014            RETURNING id
1015        ";
1016        Ok(sqlx::query_scalar(query)
1017            .bind(name)
1018            .bind(slug)
1019            .fetch_one(&self.pool)
1020            .await
1021            .map(OrgId)?)
1022    }
1023
1024    #[cfg(any(test, feature = "seed-support"))]
1025    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1026        let query = "
1027            INSERT INTO org_memberships (org_id, user_id, admin)
1028            VALUES ($1, $2, $3)
1029            ON CONFLICT DO NOTHING
1030        ";
1031        Ok(sqlx::query(query)
1032            .bind(org_id.0)
1033            .bind(user_id.0)
1034            .bind(is_admin)
1035            .execute(&self.pool)
1036            .await
1037            .map(drop)?)
1038    }
1039
1040    // channels
1041
1042    #[cfg(any(test, feature = "seed-support"))]
1043    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1044        let query = "
1045            INSERT INTO channels (owner_id, owner_is_user, name)
1046            VALUES ($1, false, $2)
1047            RETURNING id
1048        ";
1049        Ok(sqlx::query_scalar(query)
1050            .bind(org_id.0)
1051            .bind(name)
1052            .fetch_one(&self.pool)
1053            .await
1054            .map(ChannelId)?)
1055    }
1056
1057    #[allow(unused)] // Help rust-analyzer
1058    #[cfg(any(test, feature = "seed-support"))]
1059    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1060        let query = "
1061            SELECT *
1062            FROM channels
1063            WHERE
1064                channels.owner_is_user = false AND
1065                channels.owner_id = $1
1066        ";
1067        Ok(sqlx::query_as(query)
1068            .bind(org_id.0)
1069            .fetch_all(&self.pool)
1070            .await?)
1071    }
1072
1073    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1074        let query = "
1075            SELECT
1076                channels.*
1077            FROM
1078                channel_memberships, channels
1079            WHERE
1080                channel_memberships.user_id = $1 AND
1081                channel_memberships.channel_id = channels.id
1082        ";
1083        Ok(sqlx::query_as(query)
1084            .bind(user_id.0)
1085            .fetch_all(&self.pool)
1086            .await?)
1087    }
1088
1089    async fn can_user_access_channel(
1090        &self,
1091        user_id: UserId,
1092        channel_id: ChannelId,
1093    ) -> Result<bool> {
1094        let query = "
1095            SELECT id
1096            FROM channel_memberships
1097            WHERE user_id = $1 AND channel_id = $2
1098            LIMIT 1
1099        ";
1100        Ok(sqlx::query_scalar::<_, i32>(query)
1101            .bind(user_id.0)
1102            .bind(channel_id.0)
1103            .fetch_optional(&self.pool)
1104            .await
1105            .map(|e| e.is_some())?)
1106    }
1107
1108    #[cfg(any(test, feature = "seed-support"))]
1109    async fn add_channel_member(
1110        &self,
1111        channel_id: ChannelId,
1112        user_id: UserId,
1113        is_admin: bool,
1114    ) -> Result<()> {
1115        let query = "
1116            INSERT INTO channel_memberships (channel_id, user_id, admin)
1117            VALUES ($1, $2, $3)
1118            ON CONFLICT DO NOTHING
1119        ";
1120        Ok(sqlx::query(query)
1121            .bind(channel_id.0)
1122            .bind(user_id.0)
1123            .bind(is_admin)
1124            .execute(&self.pool)
1125            .await
1126            .map(drop)?)
1127    }
1128
1129    // messages
1130
1131    async fn create_channel_message(
1132        &self,
1133        channel_id: ChannelId,
1134        sender_id: UserId,
1135        body: &str,
1136        timestamp: OffsetDateTime,
1137        nonce: u128,
1138    ) -> Result<MessageId> {
1139        let query = "
1140            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1141            VALUES ($1, $2, $3, $4, $5)
1142            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1143            RETURNING id
1144        ";
1145        Ok(sqlx::query_scalar(query)
1146            .bind(channel_id.0)
1147            .bind(sender_id.0)
1148            .bind(body)
1149            .bind(timestamp)
1150            .bind(Uuid::from_u128(nonce))
1151            .fetch_one(&self.pool)
1152            .await
1153            .map(MessageId)?)
1154    }
1155
1156    async fn get_channel_messages(
1157        &self,
1158        channel_id: ChannelId,
1159        count: usize,
1160        before_id: Option<MessageId>,
1161    ) -> Result<Vec<ChannelMessage>> {
1162        let query = r#"
1163            SELECT * FROM (
1164                SELECT
1165                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1166                FROM
1167                    channel_messages
1168                WHERE
1169                    channel_id = $1 AND
1170                    id < $2
1171                ORDER BY id DESC
1172                LIMIT $3
1173            ) as recent_messages
1174            ORDER BY id ASC
1175        "#;
1176        Ok(sqlx::query_as(query)
1177            .bind(channel_id.0)
1178            .bind(before_id.unwrap_or(MessageId::MAX))
1179            .bind(count as i64)
1180            .fetch_all(&self.pool)
1181            .await?)
1182    }
1183
1184    #[cfg(test)]
1185    async fn teardown(&self, url: &str) {
1186        use util::ResultExt;
1187
1188        let query = "
1189            SELECT pg_terminate_backend(pg_stat_activity.pid)
1190            FROM pg_stat_activity
1191            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1192        ";
1193        sqlx::query(query).execute(&self.pool).await.log_err();
1194        self.pool.close().await;
1195        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1196            .await
1197            .log_err();
1198    }
1199
1200    #[cfg(test)]
1201    fn as_fake(&self) -> Option<&tests::FakeDb> {
1202        None
1203    }
1204}
1205
/// Declares `$name` as a newtype wrapper around an `i32` database id.
///
/// The generated type is `Copy`, totally ordered, hashable and `Default`, and
/// is encoded transparently (as the bare `i32`) by both sqlx and serde. It also
/// provides:
/// - `MAX`: the id wrapping `i32::MAX`;
/// - `from_proto` / `to_proto`: conversions from and to `u64` using plain `as`
///   casts, so out-of-range values truncate or sign-extend rather than error;
/// - a `Display` impl that prints the inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
        #[derive(sqlx::Type, Serialize, Deserialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            // `as` truncates `u64` values that do not fit in an `i32`.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $name(value as i32)
            }

            // `as` sign-extends, so a negative inner id becomes a very large `u64`.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let Self(raw) = *self;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1248
// `UserId`: `i32` newtype for user ids (generated by `id_type!`).
id_type!(UserId);
/// A user row, loaded via `FromRow` (columns are matched to fields by name).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub email_address: Option<String>,
    /// Admin flag; written by `Db::set_user_is_admin`.
    pub admin: bool,
    /// The user's invite code, if one has been assigned.
    pub invite_code: Option<String>,
    /// Invite allowance; written by `Db::set_invite_count` and `Db::create_users`.
    pub invite_count: i32,
    /// Written by `Db::set_user_connected_once`.
    pub connected_once: bool,
}
1260
// `ProjectId`: `i32` newtype for project ids (generated by `id_type!`).
id_type!(ProjectId);
/// A project row, loaded via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// The hosting user; presumably the user passed to `register_project` — TODO confirm.
    pub host_user_id: UserId,
    /// NOTE(review): presumably set once the host unregisters the project — confirm.
    pub unregistered: bool,
}
1268
/// Aggregated activity for one user over a queried time range.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// Time spent per project within the range; one entry per project.
    pub project_activity: Vec<(ProjectId, Duration)>,
}

/// One contiguous span during which a user was active in a project.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    project_id: ProjectId,
    start: OffsetDateTime,
    end: OffsetDateTime,
    /// File-extension counts for the project; presumably the values recorded
    /// via `update_worktree_extensions` — TODO confirm aggregation across worktrees.
    extensions: HashMap<String, usize>,
}
1283
// `OrgId`: `i32` newtype for organization ids (generated by `id_type!`).
id_type!(OrgId);
/// An organization row, loaded via `FromRow`.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Unique-looking short identifier; looked up by `find_org_by_slug`.
    pub slug: String,
}

// `ChannelId`: `i32` newtype for channel ids (generated by `id_type!`).
id_type!(ChannelId);
/// A chat channel row, loaded via `FromRow`.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Raw owner id: an org id when `owner_is_user` is false (see
    /// `create_org_channel`); presumably a user id otherwise — TODO confirm.
    pub owner_id: i32,
    pub owner_is_user: bool,
}

// `MessageId`: `i32` newtype for channel message ids (generated by `id_type!`).
id_type!(MessageId);
/// A message posted to a channel, loaded via `FromRow`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; `get_channel_messages` selects it `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied dedup key; a repeated nonce maps to the existing message.
    pub nonce: Uuid,
}
1311
1312#[derive(Clone, Debug, PartialEq, Eq)]
1313pub enum Contact {
1314    Accepted {
1315        user_id: UserId,
1316        should_notify: bool,
1317    },
1318    Outgoing {
1319        user_id: UserId,
1320    },
1321    Incoming {
1322        user_id: UserId,
1323        should_notify: bool,
1324    },
1325}
1326
1327impl Contact {
1328    pub fn user_id(&self) -> UserId {
1329        match self {
1330            Contact::Accepted { user_id, .. } => *user_id,
1331            Contact::Outgoing { user_id } => *user_id,
1332            Contact::Incoming { user_id, .. } => *user_id,
1333        }
1334    }
1335}
1336
/// A pending contact request received from `requester_id`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    /// Whether the recipient still has an undismissed notification for it.
    pub should_notify: bool,
}
1342
/// Builds a SQL `LIKE` pattern that matches `string`'s alphanumeric characters
/// in order, with anything (`%`) allowed between, before, and after them.
///
/// Non-alphanumeric characters, including the `LIKE` wildcards `%` and `_`,
/// are dropped. For example, `"x y"` becomes `"%x%y%"` and `""` becomes `"%"`.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    string
        .chars()
        .filter(|ch| ch.is_alphanumeric())
        .for_each(|ch| {
            pattern.push('%');
            pattern.push(ch);
        });
    pattern.push('%');
    pattern
}
1354
1355#[cfg(test)]
1356pub mod tests {
1357    use super::*;
1358    use anyhow::anyhow;
1359    use collections::BTreeMap;
1360    use gpui::executor::Background;
1361    use lazy_static::lazy_static;
1362    use parking_lot::Mutex;
1363    use rand::prelude::*;
1364    use sqlx::{
1365        migrate::{MigrateDatabase, Migrator},
1366        Postgres,
1367    };
1368    use std::{path::Path, sync::Arc};
1369    use util::post_inc;
1370
1371    #[tokio::test(flavor = "multi_thread")]
1372    async fn test_get_users_by_ids() {
1373        for test_db in [
1374            TestDb::postgres().await,
1375            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1376        ] {
1377            let db = test_db.db();
1378
1379            let user = db.create_user("user", None, false).await.unwrap();
1380            let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1381            let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1382            let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1383
1384            assert_eq!(
1385                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1386                    .await
1387                    .unwrap(),
1388                vec![
1389                    User {
1390                        id: user,
1391                        github_login: "user".to_string(),
1392                        admin: false,
1393                        ..Default::default()
1394                    },
1395                    User {
1396                        id: friend1,
1397                        github_login: "friend-1".to_string(),
1398                        admin: false,
1399                        ..Default::default()
1400                    },
1401                    User {
1402                        id: friend2,
1403                        github_login: "friend-2".to_string(),
1404                        admin: false,
1405                        ..Default::default()
1406                    },
1407                    User {
1408                        id: friend3,
1409                        github_login: "friend-3".to_string(),
1410                        admin: false,
1411                        ..Default::default()
1412                    }
1413                ]
1414            );
1415        }
1416    }
1417
1418    #[tokio::test(flavor = "multi_thread")]
1419    async fn test_create_users() {
1420        let db = TestDb::postgres().await;
1421        let db = db.db();
1422
1423        // Create the first batch of users, ensuring invite counts are assigned
1424        // correctly and the respective invite codes are unique.
1425        let user_ids_batch_1 = db
1426            .create_users(vec![
1427                ("user1".to_string(), "hi@user1.com".to_string(), 5),
1428                ("user2".to_string(), "hi@user2.com".to_string(), 4),
1429                ("user3".to_string(), "hi@user3.com".to_string(), 3),
1430            ])
1431            .await
1432            .unwrap();
1433        assert_eq!(user_ids_batch_1.len(), 3);
1434
1435        let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
1436        assert_eq!(users.len(), 3);
1437        assert_eq!(users[0].github_login, "user1");
1438        assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
1439        assert_eq!(users[0].invite_count, 5);
1440        assert_eq!(users[1].github_login, "user2");
1441        assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
1442        assert_eq!(users[1].invite_count, 4);
1443        assert_eq!(users[2].github_login, "user3");
1444        assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
1445        assert_eq!(users[2].invite_count, 3);
1446
1447        let invite_code_1 = users[0].invite_code.clone().unwrap();
1448        let invite_code_2 = users[1].invite_code.clone().unwrap();
1449        let invite_code_3 = users[2].invite_code.clone().unwrap();
1450        assert_ne!(invite_code_1, invite_code_2);
1451        assert_ne!(invite_code_1, invite_code_3);
1452        assert_ne!(invite_code_2, invite_code_3);
1453
1454        // Create the second batch of users and include a user that is already in the database, ensuring
1455        // the invite count for the existing user is updated without changing their invite code.
1456        let user_ids_batch_2 = db
1457            .create_users(vec![
1458                ("user2".to_string(), "hi@user2.com".to_string(), 10),
1459                ("user4".to_string(), "hi@user4.com".to_string(), 2),
1460            ])
1461            .await
1462            .unwrap();
1463        assert_eq!(user_ids_batch_2.len(), 2);
1464        assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);
1465
1466        let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
1467        assert_eq!(users.len(), 2);
1468        assert_eq!(users[0].github_login, "user2");
1469        assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
1470        assert_eq!(users[0].invite_count, 10);
1471        assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
1472        assert_eq!(users[1].github_login, "user4");
1473        assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
1474        assert_eq!(users[1].invite_count, 2);
1475
1476        let invite_code_4 = users[1].invite_code.clone().unwrap();
1477        assert_ne!(invite_code_4, invite_code_1);
1478        assert_ne!(invite_code_4, invite_code_2);
1479        assert_ne!(invite_code_4, invite_code_3);
1480    }
1481
1482    #[tokio::test(flavor = "multi_thread")]
1483    async fn test_worktree_extensions() {
1484        let test_db = TestDb::postgres().await;
1485        let db = test_db.db();
1486
1487        let user = db.create_user("user_1", None, false).await.unwrap();
1488        let project = db.register_project(user).await.unwrap();
1489
1490        db.update_worktree_extensions(project, 100, Default::default())
1491            .await
1492            .unwrap();
1493        db.update_worktree_extensions(
1494            project,
1495            100,
1496            [("rs".to_string(), 5), ("md".to_string(), 3)]
1497                .into_iter()
1498                .collect(),
1499        )
1500        .await
1501        .unwrap();
1502        db.update_worktree_extensions(
1503            project,
1504            100,
1505            [("rs".to_string(), 6), ("md".to_string(), 5)]
1506                .into_iter()
1507                .collect(),
1508        )
1509        .await
1510        .unwrap();
1511        db.update_worktree_extensions(
1512            project,
1513            101,
1514            [("ts".to_string(), 2), ("md".to_string(), 1)]
1515                .into_iter()
1516                .collect(),
1517        )
1518        .await
1519        .unwrap();
1520
1521        assert_eq!(
1522            db.get_project_extensions(project).await.unwrap(),
1523            [
1524                (
1525                    100,
1526                    [("rs".into(), 6), ("md".into(), 5),]
1527                        .into_iter()
1528                        .collect::<HashMap<_, _>>()
1529                ),
1530                (
1531                    101,
1532                    [("ts".into(), 2), ("md".into(), 1),]
1533                        .into_iter()
1534                        .collect::<HashMap<_, _>>()
1535                )
1536            ]
1537            .into_iter()
1538            .collect()
1539        );
1540    }
1541
1542    #[tokio::test(flavor = "multi_thread")]
1543    async fn test_project_activity() {
1544        let test_db = TestDb::postgres().await;
1545        let db = test_db.db();
1546
1547        let user_1 = db.create_user("user_1", None, false).await.unwrap();
1548        let user_2 = db.create_user("user_2", None, false).await.unwrap();
1549        let user_3 = db.create_user("user_3", None, false).await.unwrap();
1550        let project_1 = db.register_project(user_1).await.unwrap();
1551        db.update_worktree_extensions(
1552            project_1,
1553            1,
1554            HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]),
1555        )
1556        .await
1557        .unwrap();
1558        let project_2 = db.register_project(user_2).await.unwrap();
1559        let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);
1560
1561        // User 2 opens a project
1562        let t1 = t0 + Duration::from_secs(10);
1563        db.record_user_activity(t0..t1, &[(user_2, project_2)])
1564            .await
1565            .unwrap();
1566
1567        let t2 = t1 + Duration::from_secs(10);
1568        db.record_user_activity(t1..t2, &[(user_2, project_2)])
1569            .await
1570            .unwrap();
1571
1572        // User 1 joins the project
1573        let t3 = t2 + Duration::from_secs(10);
1574        db.record_user_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
1575            .await
1576            .unwrap();
1577
1578        // User 1 opens another project
1579        let t4 = t3 + Duration::from_secs(10);
1580        db.record_user_activity(
1581            t3..t4,
1582            &[
1583                (user_2, project_2),
1584                (user_1, project_2),
1585                (user_1, project_1),
1586            ],
1587        )
1588        .await
1589        .unwrap();
1590
1591        // User 3 joins that project
1592        let t5 = t4 + Duration::from_secs(10);
1593        db.record_user_activity(
1594            t4..t5,
1595            &[
1596                (user_2, project_2),
1597                (user_1, project_2),
1598                (user_1, project_1),
1599                (user_3, project_1),
1600            ],
1601        )
1602        .await
1603        .unwrap();
1604
1605        // User 2 leaves
1606        let t6 = t5 + Duration::from_secs(5);
1607        db.record_user_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
1608            .await
1609            .unwrap();
1610
1611        let t7 = t6 + Duration::from_secs(60);
1612        let t8 = t7 + Duration::from_secs(10);
1613        db.record_user_activity(t7..t8, &[(user_1, project_1)])
1614            .await
1615            .unwrap();
1616
1617        assert_eq!(
1618            db.get_top_users_activity_summary(t0..t6, 10).await.unwrap(),
1619            &[
1620                UserActivitySummary {
1621                    id: user_1,
1622                    github_login: "user_1".to_string(),
1623                    project_activity: vec![
1624                        (project_2, Duration::from_secs(30)),
1625                        (project_1, Duration::from_secs(25))
1626                    ]
1627                },
1628                UserActivitySummary {
1629                    id: user_2,
1630                    github_login: "user_2".to_string(),
1631                    project_activity: vec![(project_2, Duration::from_secs(50))]
1632                },
1633                UserActivitySummary {
1634                    id: user_3,
1635                    github_login: "user_3".to_string(),
1636                    project_activity: vec![(project_1, Duration::from_secs(15))]
1637                },
1638            ]
1639        );
1640        assert_eq!(
1641            db.get_user_activity_timeline(t3..t6, user_1).await.unwrap(),
1642            &[
1643                UserActivityPeriod {
1644                    project_id: project_1,
1645                    start: t3,
1646                    end: t6,
1647                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1648                },
1649                UserActivityPeriod {
1650                    project_id: project_2,
1651                    start: t3,
1652                    end: t5,
1653                    extensions: Default::default(),
1654                },
1655            ]
1656        );
1657        assert_eq!(
1658            db.get_user_activity_timeline(t0..t8, user_1).await.unwrap(),
1659            &[
1660                UserActivityPeriod {
1661                    project_id: project_2,
1662                    start: t2,
1663                    end: t5,
1664                    extensions: Default::default(),
1665                },
1666                UserActivityPeriod {
1667                    project_id: project_1,
1668                    start: t3,
1669                    end: t6,
1670                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1671                },
1672                UserActivityPeriod {
1673                    project_id: project_1,
1674                    start: t7,
1675                    end: t8,
1676                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1677                },
1678            ]
1679        );
1680    }
1681
1682    #[tokio::test(flavor = "multi_thread")]
1683    async fn test_recent_channel_messages() {
1684        for test_db in [
1685            TestDb::postgres().await,
1686            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1687        ] {
1688            let db = test_db.db();
1689            let user = db.create_user("user", None, false).await.unwrap();
1690            let org = db.create_org("org", "org").await.unwrap();
1691            let channel = db.create_org_channel(org, "channel").await.unwrap();
1692            for i in 0..10 {
1693                db.create_channel_message(
1694                    channel,
1695                    user,
1696                    &i.to_string(),
1697                    OffsetDateTime::now_utc(),
1698                    i,
1699                )
1700                .await
1701                .unwrap();
1702            }
1703
1704            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1705            assert_eq!(
1706                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1707                ["5", "6", "7", "8", "9"]
1708            );
1709
1710            let prev_messages = db
1711                .get_channel_messages(channel, 4, Some(messages[0].id))
1712                .await
1713                .unwrap();
1714            assert_eq!(
1715                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1716                ["1", "2", "3", "4"]
1717            );
1718        }
1719    }
1720
1721    #[tokio::test(flavor = "multi_thread")]
1722    async fn test_channel_message_nonces() {
1723        for test_db in [
1724            TestDb::postgres().await,
1725            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1726        ] {
1727            let db = test_db.db();
1728            let user = db.create_user("user", None, false).await.unwrap();
1729            let org = db.create_org("org", "org").await.unwrap();
1730            let channel = db.create_org_channel(org, "channel").await.unwrap();
1731
1732            let msg1_id = db
1733                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1734                .await
1735                .unwrap();
1736            let msg2_id = db
1737                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1738                .await
1739                .unwrap();
1740            let msg3_id = db
1741                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1742                .await
1743                .unwrap();
1744            let msg4_id = db
1745                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1746                .await
1747                .unwrap();
1748
1749            assert_ne!(msg1_id, msg2_id);
1750            assert_eq!(msg1_id, msg3_id);
1751            assert_eq!(msg2_id, msg4_id);
1752        }
1753    }
1754
1755    #[tokio::test(flavor = "multi_thread")]
1756    async fn test_create_access_tokens() {
1757        let test_db = TestDb::postgres().await;
1758        let db = test_db.db();
1759        let user = db.create_user("the-user", None, false).await.unwrap();
1760
1761        db.create_access_token_hash(user, "h1", 3).await.unwrap();
1762        db.create_access_token_hash(user, "h2", 3).await.unwrap();
1763        assert_eq!(
1764            db.get_access_token_hashes(user).await.unwrap(),
1765            &["h2".to_string(), "h1".to_string()]
1766        );
1767
1768        db.create_access_token_hash(user, "h3", 3).await.unwrap();
1769        assert_eq!(
1770            db.get_access_token_hashes(user).await.unwrap(),
1771            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1772        );
1773
1774        db.create_access_token_hash(user, "h4", 3).await.unwrap();
1775        assert_eq!(
1776            db.get_access_token_hashes(user).await.unwrap(),
1777            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1778        );
1779
1780        db.create_access_token_hash(user, "h5", 3).await.unwrap();
1781        assert_eq!(
1782            db.get_access_token_hashes(user).await.unwrap(),
1783            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1784        );
1785    }
1786
1787    #[test]
1788    fn test_fuzzy_like_string() {
1789        assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1790        assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1791        assert_eq!(fuzzy_like_string(" z  "), "%z%");
1792    }
1793
1794    #[tokio::test(flavor = "multi_thread")]
1795    async fn test_fuzzy_search_users() {
1796        let test_db = TestDb::postgres().await;
1797        let db = test_db.db();
1798        for github_login in [
1799            "California",
1800            "colorado",
1801            "oregon",
1802            "washington",
1803            "florida",
1804            "delaware",
1805            "rhode-island",
1806        ] {
1807            db.create_user(github_login, None, false).await.unwrap();
1808        }
1809
1810        assert_eq!(
1811            fuzzy_search_user_names(db, "clr").await,
1812            &["colorado", "California"]
1813        );
1814        assert_eq!(
1815            fuzzy_search_user_names(db, "ro").await,
1816            &["rhode-island", "colorado", "oregon"],
1817        );
1818
1819        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1820            db.fuzzy_search_users(query, 10)
1821                .await
1822                .unwrap()
1823                .into_iter()
1824                .map(|user| user.github_login)
1825                .collect::<Vec<_>>()
1826        }
1827    }
1828
1829    #[tokio::test(flavor = "multi_thread")]
1830    async fn test_add_contacts() {
1831        for test_db in [
1832            TestDb::postgres().await,
1833            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1834        ] {
1835            let db = test_db.db();
1836
1837            let user_1 = db.create_user("user1", None, false).await.unwrap();
1838            let user_2 = db.create_user("user2", None, false).await.unwrap();
1839            let user_3 = db.create_user("user3", None, false).await.unwrap();
1840
1841            // User starts with no contacts
1842            assert_eq!(
1843                db.get_contacts(user_1).await.unwrap(),
1844                vec![Contact::Accepted {
1845                    user_id: user_1,
1846                    should_notify: false
1847                }],
1848            );
1849
1850            // User requests a contact. Both users see the pending request.
1851            db.send_contact_request(user_1, user_2).await.unwrap();
1852            assert!(!db.has_contact(user_1, user_2).await.unwrap());
1853            assert!(!db.has_contact(user_2, user_1).await.unwrap());
1854            assert_eq!(
1855                db.get_contacts(user_1).await.unwrap(),
1856                &[
1857                    Contact::Accepted {
1858                        user_id: user_1,
1859                        should_notify: false
1860                    },
1861                    Contact::Outgoing { user_id: user_2 }
1862                ],
1863            );
1864            assert_eq!(
1865                db.get_contacts(user_2).await.unwrap(),
1866                &[
1867                    Contact::Incoming {
1868                        user_id: user_1,
1869                        should_notify: true
1870                    },
1871                    Contact::Accepted {
1872                        user_id: user_2,
1873                        should_notify: false
1874                    },
1875                ]
1876            );
1877
1878            // User 2 dismisses the contact request notification without accepting or rejecting.
1879            // We shouldn't notify them again.
1880            db.dismiss_contact_notification(user_1, user_2)
1881                .await
1882                .unwrap_err();
1883            db.dismiss_contact_notification(user_2, user_1)
1884                .await
1885                .unwrap();
1886            assert_eq!(
1887                db.get_contacts(user_2).await.unwrap(),
1888                &[
1889                    Contact::Incoming {
1890                        user_id: user_1,
1891                        should_notify: false
1892                    },
1893                    Contact::Accepted {
1894                        user_id: user_2,
1895                        should_notify: false
1896                    },
1897                ]
1898            );
1899
1900            // User can't accept their own contact request
1901            db.respond_to_contact_request(user_1, user_2, true)
1902                .await
1903                .unwrap_err();
1904
1905            // User accepts a contact request. Both users see the contact.
1906            db.respond_to_contact_request(user_2, user_1, true)
1907                .await
1908                .unwrap();
1909            assert_eq!(
1910                db.get_contacts(user_1).await.unwrap(),
1911                &[
1912                    Contact::Accepted {
1913                        user_id: user_1,
1914                        should_notify: false
1915                    },
1916                    Contact::Accepted {
1917                        user_id: user_2,
1918                        should_notify: true
1919                    }
1920                ],
1921            );
1922            assert!(db.has_contact(user_1, user_2).await.unwrap());
1923            assert!(db.has_contact(user_2, user_1).await.unwrap());
1924            assert_eq!(
1925                db.get_contacts(user_2).await.unwrap(),
1926                &[
1927                    Contact::Accepted {
1928                        user_id: user_1,
1929                        should_notify: false,
1930                    },
1931                    Contact::Accepted {
1932                        user_id: user_2,
1933                        should_notify: false,
1934                    },
1935                ]
1936            );
1937
1938            // Users cannot re-request existing contacts.
1939            db.send_contact_request(user_1, user_2).await.unwrap_err();
1940            db.send_contact_request(user_2, user_1).await.unwrap_err();
1941
1942            // Users can't dismiss notifications of them accepting other users' requests.
1943            db.dismiss_contact_notification(user_2, user_1)
1944                .await
1945                .unwrap_err();
1946            assert_eq!(
1947                db.get_contacts(user_1).await.unwrap(),
1948                &[
1949                    Contact::Accepted {
1950                        user_id: user_1,
1951                        should_notify: false
1952                    },
1953                    Contact::Accepted {
1954                        user_id: user_2,
1955                        should_notify: true,
1956                    },
1957                ]
1958            );
1959
1960            // Users can dismiss notifications of other users accepting their requests.
1961            db.dismiss_contact_notification(user_1, user_2)
1962                .await
1963                .unwrap();
1964            assert_eq!(
1965                db.get_contacts(user_1).await.unwrap(),
1966                &[
1967                    Contact::Accepted {
1968                        user_id: user_1,
1969                        should_notify: false
1970                    },
1971                    Contact::Accepted {
1972                        user_id: user_2,
1973                        should_notify: false,
1974                    },
1975                ]
1976            );
1977
1978            // Users send each other concurrent contact requests and
1979            // see that they are immediately accepted.
1980            db.send_contact_request(user_1, user_3).await.unwrap();
1981            db.send_contact_request(user_3, user_1).await.unwrap();
1982            assert_eq!(
1983                db.get_contacts(user_1).await.unwrap(),
1984                &[
1985                    Contact::Accepted {
1986                        user_id: user_1,
1987                        should_notify: false
1988                    },
1989                    Contact::Accepted {
1990                        user_id: user_2,
1991                        should_notify: false,
1992                    },
1993                    Contact::Accepted {
1994                        user_id: user_3,
1995                        should_notify: false
1996                    },
1997                ]
1998            );
1999            assert_eq!(
2000                db.get_contacts(user_3).await.unwrap(),
2001                &[
2002                    Contact::Accepted {
2003                        user_id: user_1,
2004                        should_notify: false
2005                    },
2006                    Contact::Accepted {
2007                        user_id: user_3,
2008                        should_notify: false
2009                    }
2010                ],
2011            );
2012
2013            // User declines a contact request. Both users see that it is gone.
2014            db.send_contact_request(user_2, user_3).await.unwrap();
2015            db.respond_to_contact_request(user_3, user_2, false)
2016                .await
2017                .unwrap();
2018            assert!(!db.has_contact(user_2, user_3).await.unwrap());
2019            assert!(!db.has_contact(user_3, user_2).await.unwrap());
2020            assert_eq!(
2021                db.get_contacts(user_2).await.unwrap(),
2022                &[
2023                    Contact::Accepted {
2024                        user_id: user_1,
2025                        should_notify: false
2026                    },
2027                    Contact::Accepted {
2028                        user_id: user_2,
2029                        should_notify: false
2030                    }
2031                ]
2032            );
2033            assert_eq!(
2034                db.get_contacts(user_3).await.unwrap(),
2035                &[
2036                    Contact::Accepted {
2037                        user_id: user_1,
2038                        should_notify: false
2039                    },
2040                    Contact::Accepted {
2041                        user_id: user_3,
2042                        should_notify: false
2043                    }
2044                ],
2045            );
2046        }
2047    }
2048
2049    #[tokio::test(flavor = "multi_thread")]
2050    async fn test_invite_codes() {
2051        let postgres = TestDb::postgres().await;
2052        let db = postgres.db();
2053        let user1 = db.create_user("user-1", None, false).await.unwrap();
2054
2055        // Initially, user 1 has no invite code
2056        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);
2057
2058        // Setting invite count to 0 when no code is assigned does not assign a new code
2059        db.set_invite_count(user1, 0).await.unwrap();
2060        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());
2061
2062        // User 1 creates an invite code that can be used twice.
2063        db.set_invite_count(user1, 2).await.unwrap();
2064        let (invite_code, invite_count) =
2065            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2066        assert_eq!(invite_count, 2);
2067
2068        // User 2 redeems the invite code and becomes a contact of user 1.
2069        let user2 = db
2070            .redeem_invite_code(&invite_code, "user-2", None)
2071            .await
2072            .unwrap();
2073        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2074        assert_eq!(invite_count, 1);
2075        assert_eq!(
2076            db.get_contacts(user1).await.unwrap(),
2077            [
2078                Contact::Accepted {
2079                    user_id: user1,
2080                    should_notify: false
2081                },
2082                Contact::Accepted {
2083                    user_id: user2,
2084                    should_notify: true
2085                }
2086            ]
2087        );
2088        assert_eq!(
2089            db.get_contacts(user2).await.unwrap(),
2090            [
2091                Contact::Accepted {
2092                    user_id: user1,
2093                    should_notify: false
2094                },
2095                Contact::Accepted {
2096                    user_id: user2,
2097                    should_notify: false
2098                }
2099            ]
2100        );
2101
2102        // User 3 redeems the invite code and becomes a contact of user 1.
2103        let user3 = db
2104            .redeem_invite_code(&invite_code, "user-3", None)
2105            .await
2106            .unwrap();
2107        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2108        assert_eq!(invite_count, 0);
2109        assert_eq!(
2110            db.get_contacts(user1).await.unwrap(),
2111            [
2112                Contact::Accepted {
2113                    user_id: user1,
2114                    should_notify: false
2115                },
2116                Contact::Accepted {
2117                    user_id: user2,
2118                    should_notify: true
2119                },
2120                Contact::Accepted {
2121                    user_id: user3,
2122                    should_notify: true
2123                }
2124            ]
2125        );
2126        assert_eq!(
2127            db.get_contacts(user3).await.unwrap(),
2128            [
2129                Contact::Accepted {
2130                    user_id: user1,
2131                    should_notify: false
2132                },
2133                Contact::Accepted {
2134                    user_id: user3,
2135                    should_notify: false
2136                },
2137            ]
2138        );
2139
2140        // Trying to reedem the code for the third time results in an error.
2141        db.redeem_invite_code(&invite_code, "user-4", None)
2142            .await
2143            .unwrap_err();
2144
2145        // Invite count can be updated after the code has been created.
2146        db.set_invite_count(user1, 2).await.unwrap();
2147        let (latest_code, invite_count) =
2148            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2149        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
2150        assert_eq!(invite_count, 2);
2151
2152        // User 4 can now redeem the invite code and becomes a contact of user 1.
2153        let user4 = db
2154            .redeem_invite_code(&invite_code, "user-4", None)
2155            .await
2156            .unwrap();
2157        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2158        assert_eq!(invite_count, 1);
2159        assert_eq!(
2160            db.get_contacts(user1).await.unwrap(),
2161            [
2162                Contact::Accepted {
2163                    user_id: user1,
2164                    should_notify: false
2165                },
2166                Contact::Accepted {
2167                    user_id: user2,
2168                    should_notify: true
2169                },
2170                Contact::Accepted {
2171                    user_id: user3,
2172                    should_notify: true
2173                },
2174                Contact::Accepted {
2175                    user_id: user4,
2176                    should_notify: true
2177                }
2178            ]
2179        );
2180        assert_eq!(
2181            db.get_contacts(user4).await.unwrap(),
2182            [
2183                Contact::Accepted {
2184                    user_id: user1,
2185                    should_notify: false
2186                },
2187                Contact::Accepted {
2188                    user_id: user4,
2189                    should_notify: false
2190                },
2191            ]
2192        );
2193
2194        // An existing user cannot redeem invite codes.
2195        db.redeem_invite_code(&invite_code, "user-2", None)
2196            .await
2197            .unwrap_err();
2198        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2199        assert_eq!(invite_count, 1);
2200    }
2201
    /// A database handle for tests. It is backed either by a freshly created
    /// Postgres database (`TestDb::postgres`) or by an in-memory `FakeDb`
    /// (`TestDb::fake`).
    pub struct TestDb {
        /// The database under test. It is taken (left as `None`) when the `TestDb` is dropped.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the Postgres database. It is empty for fake databases.
        pub url: String,
    }
2206
2207    impl TestDb {
2208        pub async fn postgres() -> Self {
2209            lazy_static! {
2210                static ref LOCK: Mutex<()> = Mutex::new(());
2211            }
2212
2213            let _guard = LOCK.lock();
2214            let mut rng = StdRng::from_entropy();
2215            let name = format!("zed-test-{}", rng.gen::<u128>());
2216            let url = format!("postgres://postgres@localhost/{}", name);
2217            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2218            Postgres::create_database(&url)
2219                .await
2220                .expect("failed to create test db");
2221            let db = PostgresDb::new(&url, 5).await.unwrap();
2222            let migrator = Migrator::new(migrations_path).await.unwrap();
2223            migrator.run(&db.pool).await.unwrap();
2224            Self {
2225                db: Some(Arc::new(db)),
2226                url,
2227            }
2228        }
2229
2230        pub fn fake(background: Arc<Background>) -> Self {
2231            Self {
2232                db: Some(Arc::new(FakeDb::new(background))),
2233                url: Default::default(),
2234            }
2235        }
2236
2237        pub fn db(&self) -> &Arc<dyn Db> {
2238            self.db.as_ref().unwrap()
2239        }
2240    }
2241
2242    impl Drop for TestDb {
2243        fn drop(&mut self) {
2244            if let Some(db) = self.db.take() {
2245                futures::executor::block_on(db.teardown(&self.url));
2246            }
2247        }
2248    }
2249
    /// In-memory implementation of the `Db` trait for tests. Each table is a
    /// mutex-guarded collection, and new ids come from per-table counters.
    pub struct FakeDb {
        /// Executor used to inject random delays into async operations.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        /// Extension counts per worktree, keyed by (project id, worktree id, extension).
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), usize>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        /// NOTE(review): the meaning of the `bool` value is not visible here
        /// (presumably an admin flag). Confirm against its users.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        /// NOTE(review): as with `org_memberships`, the `bool` meaning needs confirming.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        /// Contact records between pairs of users, in insertion order.
        pub contacts: Mutex<Vec<FakeContact>>,
        // Counters used to allocate new ids (presumably via `post_inc`, as the
        // user/org/project ones are).
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
2267
    /// One contact relationship stored by `FakeDb`: a request from
    /// `requester_id` to `responder_id` that is either pending or accepted.
    #[derive(Debug)]
    pub struct FakeContact {
        /// The user who sent the contact request.
        pub requester_id: UserId,
        /// The user the request was sent to.
        pub responder_id: UserId,
        /// `false` while the request is pending, `true` once it has been accepted.
        pub accepted: bool,
        /// Whether a notification is outstanding. While the request is pending,
        /// it is for the responder about the incoming request. Once accepted,
        /// it is for the requester about the acceptance.
        pub should_notify: bool,
    }
2275
2276    impl FakeDb {
2277        pub fn new(background: Arc<Background>) -> Self {
2278            Self {
2279                background,
2280                users: Default::default(),
2281                next_user_id: Mutex::new(1),
2282                projects: Default::default(),
2283                worktree_extensions: Default::default(),
2284                next_project_id: Mutex::new(1),
2285                orgs: Default::default(),
2286                next_org_id: Mutex::new(1),
2287                org_memberships: Default::default(),
2288                channels: Default::default(),
2289                next_channel_id: Mutex::new(1),
2290                channel_memberships: Default::default(),
2291                channel_messages: Default::default(),
2292                next_channel_message_id: Mutex::new(1),
2293                contacts: Default::default(),
2294            }
2295        }
2296    }
2297
2298    #[async_trait]
2299    impl Db for FakeDb {
2300        async fn create_user(
2301            &self,
2302            github_login: &str,
2303            email_address: Option<&str>,
2304            admin: bool,
2305        ) -> Result<UserId> {
2306            self.background.simulate_random_delay().await;
2307
2308            let mut users = self.users.lock();
2309            if let Some(user) = users
2310                .values()
2311                .find(|user| user.github_login == github_login)
2312            {
2313                Ok(user.id)
2314            } else {
2315                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2316                users.insert(
2317                    user_id,
2318                    User {
2319                        id: user_id,
2320                        github_login: github_login.to_string(),
2321                        email_address: email_address.map(str::to_string),
2322                        admin,
2323                        invite_code: None,
2324                        invite_count: 0,
2325                        connected_once: false,
2326                    },
2327                );
2328                Ok(user_id)
2329            }
2330        }
2331
        /// Not supported by `FakeDb`; panics if called.
        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }

        /// Not supported by `FakeDb`; panics if called.
        async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
            unimplemented!()
        }

        /// Not supported by `FakeDb`; panics if called.
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2343
2344        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2345            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2346        }
2347
2348        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2349            self.background.simulate_random_delay().await;
2350            let users = self.users.lock();
2351            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2352        }
2353
        /// Not supported by `FakeDb`; panics if called.
        async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
            unimplemented!()
        }
2357
2358        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2359            Ok(self
2360                .users
2361                .lock()
2362                .values()
2363                .find(|user| user.github_login == github_login)
2364                .cloned())
2365        }
2366
        /// Not supported by `FakeDb`; panics if called.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
2370
2371        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2372            self.background.simulate_random_delay().await;
2373            let mut users = self.users.lock();
2374            let mut user = users
2375                .get_mut(&id)
2376                .ok_or_else(|| anyhow!("user not found"))?;
2377            user.connected_once = connected_once;
2378            Ok(())
2379        }
2380
        /// Not supported by `FakeDb`; panics if called.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }
2384
2385        // invite codes
2386
        /// Not supported by `FakeDb`; panics if called.
        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }
2390
        /// `FakeDb` does not track invite codes, so no user ever has one.
        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            Ok(None)
        }
2394
        /// Not supported by `FakeDb`; panics if called.
        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }

        /// Not supported by `FakeDb`; panics if called.
        async fn redeem_invite_code(
            &self,
            _code: &str,
            _login: &str,
            _email_address: Option<&str>,
        ) -> Result<UserId> {
            unimplemented!()
        }
2407
2408        // projects
2409
2410        async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2411            self.background.simulate_random_delay().await;
2412            if !self.users.lock().contains_key(&host_user_id) {
2413                Err(anyhow!("no such user"))?;
2414            }
2415
2416            let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2417            self.projects.lock().insert(
2418                project_id,
2419                Project {
2420                    id: project_id,
2421                    host_user_id,
2422                    unregistered: false,
2423                },
2424            );
2425            Ok(project_id)
2426        }
2427
2428        async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2429            self.projects
2430                .lock()
2431                .get_mut(&project_id)
2432                .ok_or_else(|| anyhow!("no such project"))?
2433                .unregistered = true;
2434            Ok(())
2435        }
2436
2437        async fn update_worktree_extensions(
2438            &self,
2439            project_id: ProjectId,
2440            worktree_id: u64,
2441            extensions: HashMap<String, usize>,
2442        ) -> Result<()> {
2443            self.background.simulate_random_delay().await;
2444            if !self.projects.lock().contains_key(&project_id) {
2445                Err(anyhow!("no such project"))?;
2446            }
2447
2448            for (extension, count) in extensions {
2449                self.worktree_extensions
2450                    .lock()
2451                    .insert((project_id, worktree_id, extension), count);
2452            }
2453
2454            Ok(())
2455        }
2456
        /// Not supported by `FakeDb`; panics if called.
        async fn get_project_extensions(
            &self,
            _project_id: ProjectId,
        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
            unimplemented!()
        }

        /// Not supported by `FakeDb`; panics if called.
        async fn record_user_activity(
            &self,
            _time_period: Range<OffsetDateTime>,
            _active_projects: &[(UserId, ProjectId)],
        ) -> Result<()> {
            unimplemented!()
        }

        /// Not supported by `FakeDb`; panics if called.
        async fn get_top_users_activity_summary(
            &self,
            _time_period: Range<OffsetDateTime>,
            _limit: usize,
        ) -> Result<Vec<UserActivitySummary>> {
            unimplemented!()
        }

        /// Not supported by `FakeDb`; panics if called.
        async fn get_user_activity_timeline(
            &self,
            _time_period: Range<OffsetDateTime>,
            _user_id: UserId,
        ) -> Result<Vec<UserActivityPeriod>> {
            unimplemented!()
        }
2487
2488        // contacts
2489
2490        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2491            self.background.simulate_random_delay().await;
2492            let mut contacts = vec![Contact::Accepted {
2493                user_id: id,
2494                should_notify: false,
2495            }];
2496
2497            for contact in self.contacts.lock().iter() {
2498                if contact.requester_id == id {
2499                    if contact.accepted {
2500                        contacts.push(Contact::Accepted {
2501                            user_id: contact.responder_id,
2502                            should_notify: contact.should_notify,
2503                        });
2504                    } else {
2505                        contacts.push(Contact::Outgoing {
2506                            user_id: contact.responder_id,
2507                        });
2508                    }
2509                } else if contact.responder_id == id {
2510                    if contact.accepted {
2511                        contacts.push(Contact::Accepted {
2512                            user_id: contact.requester_id,
2513                            should_notify: false,
2514                        });
2515                    } else {
2516                        contacts.push(Contact::Incoming {
2517                            user_id: contact.requester_id,
2518                            should_notify: contact.should_notify,
2519                        });
2520                    }
2521                }
2522            }
2523
2524            contacts.sort_unstable_by_key(|contact| contact.user_id());
2525            Ok(contacts)
2526        }
2527
2528        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2529            self.background.simulate_random_delay().await;
2530            Ok(self.contacts.lock().iter().any(|contact| {
2531                contact.accepted
2532                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2533                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2534            }))
2535        }
2536
2537        async fn send_contact_request(
2538            &self,
2539            requester_id: UserId,
2540            responder_id: UserId,
2541        ) -> Result<()> {
2542            let mut contacts = self.contacts.lock();
2543            for contact in contacts.iter_mut() {
2544                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2545                    if contact.accepted {
2546                        Err(anyhow!("contact already exists"))?;
2547                    } else {
2548                        Err(anyhow!("contact already requested"))?;
2549                    }
2550                }
2551                if contact.responder_id == requester_id && contact.requester_id == responder_id {
2552                    if contact.accepted {
2553                        Err(anyhow!("contact already exists"))?;
2554                    } else {
2555                        contact.accepted = true;
2556                        contact.should_notify = false;
2557                        return Ok(());
2558                    }
2559                }
2560            }
2561            contacts.push(FakeContact {
2562                requester_id,
2563                responder_id,
2564                accepted: false,
2565                should_notify: true,
2566            });
2567            Ok(())
2568        }
2569
2570        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2571            self.contacts.lock().retain(|contact| {
2572                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2573            });
2574            Ok(())
2575        }
2576
2577        async fn dismiss_contact_notification(
2578            &self,
2579            user_id: UserId,
2580            contact_user_id: UserId,
2581        ) -> Result<()> {
2582            let mut contacts = self.contacts.lock();
2583            for contact in contacts.iter_mut() {
2584                if contact.requester_id == contact_user_id
2585                    && contact.responder_id == user_id
2586                    && !contact.accepted
2587                {
2588                    contact.should_notify = false;
2589                    return Ok(());
2590                }
2591                if contact.requester_id == user_id
2592                    && contact.responder_id == contact_user_id
2593                    && contact.accepted
2594                {
2595                    contact.should_notify = false;
2596                    return Ok(());
2597                }
2598            }
2599            Err(anyhow!("no such notification"))?
2600        }
2601
2602        async fn respond_to_contact_request(
2603            &self,
2604            responder_id: UserId,
2605            requester_id: UserId,
2606            accept: bool,
2607        ) -> Result<()> {
2608            let mut contacts = self.contacts.lock();
2609            for (ix, contact) in contacts.iter_mut().enumerate() {
2610                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2611                    if contact.accepted {
2612                        Err(anyhow!("contact already confirmed"))?;
2613                    }
2614                    if accept {
2615                        contact.accepted = true;
2616                        contact.should_notify = true;
2617                    } else {
2618                        contacts.remove(ix);
2619                    }
2620                    return Ok(());
2621                }
2622            }
2623            Err(anyhow!("no such contact request"))?
2624        }
2625
        /// Not supported by the fake database: always panics via `unimplemented!()`.
        ///
        /// The parameters are deliberately ignored (underscore-prefixed).
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }
2634
        /// Not supported by the fake database: always panics via `unimplemented!()`.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }
2638
        /// Not supported by the fake database: always panics via `unimplemented!()`.
        ///
        /// NOTE(review): `self.orgs` already stores `slug`, so a lookup could be
        /// implemented here if `Org` can be cloned — confirm before relying on it.
        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }
2642
2643        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2644            self.background.simulate_random_delay().await;
2645            let mut orgs = self.orgs.lock();
2646            if orgs.values().any(|org| org.slug == slug) {
2647                Err(anyhow!("org already exists"))?
2648            } else {
2649                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2650                orgs.insert(
2651                    org_id,
2652                    Org {
2653                        id: org_id,
2654                        name: name.to_string(),
2655                        slug: slug.to_string(),
2656                    },
2657                );
2658                Ok(org_id)
2659            }
2660        }
2661
2662        async fn add_org_member(
2663            &self,
2664            org_id: OrgId,
2665            user_id: UserId,
2666            is_admin: bool,
2667        ) -> Result<()> {
2668            self.background.simulate_random_delay().await;
2669            if !self.orgs.lock().contains_key(&org_id) {
2670                Err(anyhow!("org does not exist"))?;
2671            }
2672            if !self.users.lock().contains_key(&user_id) {
2673                Err(anyhow!("user does not exist"))?;
2674            }
2675
2676            self.org_memberships
2677                .lock()
2678                .entry((org_id, user_id))
2679                .or_insert(is_admin);
2680            Ok(())
2681        }
2682
2683        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2684            self.background.simulate_random_delay().await;
2685            if !self.orgs.lock().contains_key(&org_id) {
2686                Err(anyhow!("org does not exist"))?;
2687            }
2688
2689            let mut channels = self.channels.lock();
2690            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2691            channels.insert(
2692                channel_id,
2693                Channel {
2694                    id: channel_id,
2695                    name: name.to_string(),
2696                    owner_id: org_id.0,
2697                    owner_is_user: false,
2698                },
2699            );
2700            Ok(channel_id)
2701        }
2702
2703        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2704            self.background.simulate_random_delay().await;
2705            Ok(self
2706                .channels
2707                .lock()
2708                .values()
2709                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2710                .cloned()
2711                .collect())
2712        }
2713
2714        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2715            self.background.simulate_random_delay().await;
2716            let channels = self.channels.lock();
2717            let memberships = self.channel_memberships.lock();
2718            Ok(channels
2719                .values()
2720                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2721                .cloned()
2722                .collect())
2723        }
2724
2725        async fn can_user_access_channel(
2726            &self,
2727            user_id: UserId,
2728            channel_id: ChannelId,
2729        ) -> Result<bool> {
2730            self.background.simulate_random_delay().await;
2731            Ok(self
2732                .channel_memberships
2733                .lock()
2734                .contains_key(&(channel_id, user_id)))
2735        }
2736
2737        async fn add_channel_member(
2738            &self,
2739            channel_id: ChannelId,
2740            user_id: UserId,
2741            is_admin: bool,
2742        ) -> Result<()> {
2743            self.background.simulate_random_delay().await;
2744            if !self.channels.lock().contains_key(&channel_id) {
2745                Err(anyhow!("channel does not exist"))?;
2746            }
2747            if !self.users.lock().contains_key(&user_id) {
2748                Err(anyhow!("user does not exist"))?;
2749            }
2750
2751            self.channel_memberships
2752                .lock()
2753                .entry((channel_id, user_id))
2754                .or_insert(is_admin);
2755            Ok(())
2756        }
2757
2758        async fn create_channel_message(
2759            &self,
2760            channel_id: ChannelId,
2761            sender_id: UserId,
2762            body: &str,
2763            timestamp: OffsetDateTime,
2764            nonce: u128,
2765        ) -> Result<MessageId> {
2766            self.background.simulate_random_delay().await;
2767            if !self.channels.lock().contains_key(&channel_id) {
2768                Err(anyhow!("channel does not exist"))?;
2769            }
2770            if !self.users.lock().contains_key(&sender_id) {
2771                Err(anyhow!("user does not exist"))?;
2772            }
2773
2774            let mut messages = self.channel_messages.lock();
2775            if let Some(message) = messages
2776                .values()
2777                .find(|message| message.nonce.as_u128() == nonce)
2778            {
2779                Ok(message.id)
2780            } else {
2781                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2782                messages.insert(
2783                    message_id,
2784                    ChannelMessage {
2785                        id: message_id,
2786                        channel_id,
2787                        sender_id,
2788                        body: body.to_string(),
2789                        sent_at: timestamp,
2790                        nonce: Uuid::from_u128(nonce),
2791                    },
2792                );
2793                Ok(message_id)
2794            }
2795        }
2796
2797        async fn get_channel_messages(
2798            &self,
2799            channel_id: ChannelId,
2800            count: usize,
2801            before_id: Option<MessageId>,
2802        ) -> Result<Vec<ChannelMessage>> {
2803            let mut messages = self
2804                .channel_messages
2805                .lock()
2806                .values()
2807                .rev()
2808                .filter(|message| {
2809                    message.channel_id == channel_id
2810                        && message.id < before_id.unwrap_or(MessageId::MAX)
2811                })
2812                .take(count)
2813                .cloned()
2814                .collect::<Vec<_>>();
2815            messages.sort_unstable_by_key(|message| message.id);
2816            Ok(messages)
2817        }
2818
        /// No-op for the fake database; the string argument is ignored.
        async fn teardown(&self, _: &str) {}
2820
        /// Exposes this instance as a `FakeDb`; always returns `Some(self)`.
        ///
        /// Only compiled in test builds.
        #[cfg(test)]
        fn as_fake(&self) -> Option<&FakeDb> {
            Some(self)
        }
2825    }
2826}