db.rs

   1use std::{cmp, ops::Range, time::Duration};
   2
   3use crate::{Error, Result};
   4use anyhow::{anyhow, Context};
   5use async_trait::async_trait;
   6use axum::http::StatusCode;
   7use collections::HashMap;
   8use futures::StreamExt;
   9use nanoid::nanoid;
  10use serde::{Deserialize, Serialize};
  11pub use sqlx::postgres::PgPoolOptions as DbOptions;
  12use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
  13use time::{OffsetDateTime, PrimitiveDateTime};
  14
  15#[async_trait]
  16pub trait Db: Send + Sync {
  17    async fn create_user(
  18        &self,
  19        github_login: &str,
  20        email_address: Option<&str>,
  21        admin: bool,
  22    ) -> Result<UserId>;
  23    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
  24    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
  25    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
  26    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
  27    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
  28    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
  29    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
  30    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
  31    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
  32    async fn destroy_user(&self, id: UserId) -> Result<()>;
  33
  34    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
  35    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
  36    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
  37    async fn redeem_invite_code(
  38        &self,
  39        code: &str,
  40        login: &str,
  41        email_address: Option<&str>,
  42    ) -> Result<UserId>;
  43
  44    /// Registers a new project for the given user.
  45    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
  46
  47    /// Unregisters a project for the given project id.
  48    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
  49
  50    /// Update file counts by extension for the given project and worktree.
  51    async fn update_worktree_extensions(
  52        &self,
  53        project_id: ProjectId,
  54        worktree_id: u64,
  55        extensions: HashMap<String, u32>,
  56    ) -> Result<()>;
  57
  58    /// Get the file counts on the given project keyed by their worktree and extension.
  59    async fn get_project_extensions(
  60        &self,
  61        project_id: ProjectId,
  62    ) -> Result<HashMap<u64, HashMap<String, usize>>>;
  63
  64    /// Record which users have been active in which projects during
  65    /// a given period of time.
  66    async fn record_user_activity(
  67        &self,
  68        time_period: Range<OffsetDateTime>,
  69        active_projects: &[(UserId, ProjectId)],
  70    ) -> Result<()>;
  71
  72    /// Get the users that have been most active during the given time period,
  73    /// along with the amount of time they have been active in each project.
  74    async fn get_top_users_activity_summary(
  75        &self,
  76        time_period: Range<OffsetDateTime>,
  77        max_user_count: usize,
  78    ) -> Result<Vec<UserActivitySummary>>;
  79
  80    /// Get the project activity for the given user and time period.
  81    async fn get_user_activity_timeline(
  82        &self,
  83        time_period: Range<OffsetDateTime>,
  84        user_id: UserId,
  85    ) -> Result<Vec<UserActivityPeriod>>;
  86
  87    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
  88    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
  89    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  90    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
  91    async fn dismiss_contact_notification(
  92        &self,
  93        responder_id: UserId,
  94        requester_id: UserId,
  95    ) -> Result<()>;
  96    async fn respond_to_contact_request(
  97        &self,
  98        responder_id: UserId,
  99        requester_id: UserId,
 100        accept: bool,
 101    ) -> Result<()>;
 102
 103    async fn create_access_token_hash(
 104        &self,
 105        user_id: UserId,
 106        access_token_hash: &str,
 107        max_access_token_count: usize,
 108    ) -> Result<()>;
 109    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
 110    #[cfg(any(test, feature = "seed-support"))]
 111
 112    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
 113    #[cfg(any(test, feature = "seed-support"))]
 114    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
 115    #[cfg(any(test, feature = "seed-support"))]
 116    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
 117    #[cfg(any(test, feature = "seed-support"))]
 118    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
 119    #[cfg(any(test, feature = "seed-support"))]
 120
 121    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
 122    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
 123    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
 124        -> Result<bool>;
 125    #[cfg(any(test, feature = "seed-support"))]
 126    async fn add_channel_member(
 127        &self,
 128        channel_id: ChannelId,
 129        user_id: UserId,
 130        is_admin: bool,
 131    ) -> Result<()>;
 132    async fn create_channel_message(
 133        &self,
 134        channel_id: ChannelId,
 135        sender_id: UserId,
 136        body: &str,
 137        timestamp: OffsetDateTime,
 138        nonce: u128,
 139    ) -> Result<MessageId>;
 140    async fn get_channel_messages(
 141        &self,
 142        channel_id: ChannelId,
 143        count: usize,
 144        before_id: Option<MessageId>,
 145    ) -> Result<Vec<ChannelMessage>>;
 146    #[cfg(test)]
 147    async fn teardown(&self, url: &str);
 148    #[cfg(test)]
 149    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
 150}
 151
 152pub struct PostgresDb {
 153    pool: sqlx::PgPool,
 154}
 155
 156impl PostgresDb {
 157    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 158        let pool = DbOptions::new()
 159            .max_connections(max_connections)
 160            .connect(&url)
 161            .await
 162            .context("failed to connect to postgres database")?;
 163        Ok(Self { pool })
 164    }
 165}
 166
 167#[async_trait]
 168impl Db for PostgresDb {
 169    // users
 170
 171    async fn create_user(
 172        &self,
 173        github_login: &str,
 174        email_address: Option<&str>,
 175        admin: bool,
 176    ) -> Result<UserId> {
 177        let query = "
 178            INSERT INTO users (github_login, email_address, admin)
 179            VALUES ($1, $2, $3)
 180            ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 181            RETURNING id
 182        ";
 183        Ok(sqlx::query_scalar(query)
 184            .bind(github_login)
 185            .bind(email_address)
 186            .bind(admin)
 187            .fetch_one(&self.pool)
 188            .await
 189            .map(UserId)?)
 190    }
 191
 192    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 193        let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
 194        Ok(sqlx::query_as(query)
 195            .bind(limit as i32)
 196            .bind((page * limit) as i32)
 197            .fetch_all(&self.pool)
 198            .await?)
 199    }
 200
 201    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
 202        let mut query = QueryBuilder::new(
 203            "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
 204        );
 205        query.push_values(
 206            users,
 207            |mut query, (github_login, email_address, invite_count)| {
 208                query
 209                    .push_bind(github_login)
 210                    .push_bind(email_address)
 211                    .push_bind(false)
 212                    .push_bind(nanoid!(16))
 213                    .push_bind(invite_count as i32);
 214            },
 215        );
 216        query.push(
 217            "
 218            ON CONFLICT (github_login) DO UPDATE SET
 219                github_login = excluded.github_login,
 220                invite_count = excluded.invite_count,
 221                invite_code = CASE WHEN users.invite_code IS NULL
 222                                   THEN excluded.invite_code
 223                                   ELSE users.invite_code
 224                              END
 225            RETURNING id
 226            ",
 227        );
 228
 229        let rows = query.build().fetch_all(&self.pool).await?;
 230        Ok(rows
 231            .into_iter()
 232            .filter_map(|row| row.try_get::<UserId, _>(0).ok())
 233            .collect())
 234    }
 235
 236    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 237        let like_string = fuzzy_like_string(name_query);
 238        let query = "
 239            SELECT users.*
 240            FROM users
 241            WHERE github_login ILIKE $1
 242            ORDER BY github_login <-> $2
 243            LIMIT $3
 244        ";
 245        Ok(sqlx::query_as(query)
 246            .bind(like_string)
 247            .bind(name_query)
 248            .bind(limit as i32)
 249            .fetch_all(&self.pool)
 250            .await?)
 251    }
 252
 253    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
 254        let users = self.get_users_by_ids(vec![id]).await?;
 255        Ok(users.into_iter().next())
 256    }
 257
 258    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
 259        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
 260        let query = "
 261            SELECT users.*
 262            FROM users
 263            WHERE users.id = ANY ($1)
 264        ";
 265        Ok(sqlx::query_as(query)
 266            .bind(&ids)
 267            .fetch_all(&self.pool)
 268            .await?)
 269    }
 270
 271    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
 272        let query = format!(
 273            "
 274            SELECT users.*
 275            FROM users
 276            WHERE invite_count = 0
 277            AND inviter_id IS{} NULL
 278            ",
 279            if invited_by_another_user { " NOT" } else { "" }
 280        );
 281
 282        Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
 283    }
 284
 285    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
 286        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
 287        Ok(sqlx::query_as(query)
 288            .bind(github_login)
 289            .fetch_optional(&self.pool)
 290            .await?)
 291    }
 292
 293    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 294        let query = "UPDATE users SET admin = $1 WHERE id = $2";
 295        Ok(sqlx::query(query)
 296            .bind(is_admin)
 297            .bind(id.0)
 298            .execute(&self.pool)
 299            .await
 300            .map(drop)?)
 301    }
 302
 303    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 304        let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
 305        Ok(sqlx::query(query)
 306            .bind(connected_once)
 307            .bind(id.0)
 308            .execute(&self.pool)
 309            .await
 310            .map(drop)?)
 311    }
 312
 313    async fn destroy_user(&self, id: UserId) -> Result<()> {
 314        let query = "DELETE FROM access_tokens WHERE user_id = $1;";
 315        sqlx::query(query)
 316            .bind(id.0)
 317            .execute(&self.pool)
 318            .await
 319            .map(drop)?;
 320        let query = "DELETE FROM users WHERE id = $1;";
 321        Ok(sqlx::query(query)
 322            .bind(id.0)
 323            .execute(&self.pool)
 324            .await
 325            .map(drop)?)
 326    }
 327
 328    // invite codes
 329
 330    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
 331        let mut tx = self.pool.begin().await?;
 332        if count > 0 {
 333            sqlx::query(
 334                "
 335                UPDATE users
 336                SET invite_code = $1
 337                WHERE id = $2 AND invite_code IS NULL
 338            ",
 339            )
 340            .bind(nanoid!(16))
 341            .bind(id)
 342            .execute(&mut tx)
 343            .await?;
 344        }
 345
 346        sqlx::query(
 347            "
 348            UPDATE users
 349            SET invite_count = $1
 350            WHERE id = $2
 351            ",
 352        )
 353        .bind(count as i32)
 354        .bind(id)
 355        .execute(&mut tx)
 356        .await?;
 357        tx.commit().await?;
 358        Ok(())
 359    }
 360
 361    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 362        let result: Option<(String, i32)> = sqlx::query_as(
 363            "
 364                SELECT invite_code, invite_count
 365                FROM users
 366                WHERE id = $1 AND invite_code IS NOT NULL 
 367            ",
 368        )
 369        .bind(id)
 370        .fetch_optional(&self.pool)
 371        .await?;
 372        if let Some((code, count)) = result {
 373            Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
 374        } else {
 375            Ok(None)
 376        }
 377    }
 378
 379    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 380        sqlx::query_as(
 381            "
 382                SELECT *
 383                FROM users
 384                WHERE invite_code = $1
 385            ",
 386        )
 387        .bind(code)
 388        .fetch_optional(&self.pool)
 389        .await?
 390        .ok_or_else(|| {
 391            Error::Http(
 392                StatusCode::NOT_FOUND,
 393                "that invite code does not exist".to_string(),
 394            )
 395        })
 396    }
 397
 398    async fn redeem_invite_code(
 399        &self,
 400        code: &str,
 401        login: &str,
 402        email_address: Option<&str>,
 403    ) -> Result<UserId> {
 404        let mut tx = self.pool.begin().await?;
 405
 406        let inviter_id: Option<UserId> = sqlx::query_scalar(
 407            "
 408                UPDATE users
 409                SET invite_count = invite_count - 1
 410                WHERE
 411                    invite_code = $1 AND
 412                    invite_count > 0
 413                RETURNING id
 414            ",
 415        )
 416        .bind(code)
 417        .fetch_optional(&mut tx)
 418        .await?;
 419
 420        let inviter_id = match inviter_id {
 421            Some(inviter_id) => inviter_id,
 422            None => {
 423                if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
 424                    .bind(code)
 425                    .fetch_optional(&mut tx)
 426                    .await?
 427                    .is_some()
 428                {
 429                    Err(Error::Http(
 430                        StatusCode::UNAUTHORIZED,
 431                        "no invites remaining".to_string(),
 432                    ))?
 433                } else {
 434                    Err(Error::Http(
 435                        StatusCode::NOT_FOUND,
 436                        "invite code not found".to_string(),
 437                    ))?
 438                }
 439            }
 440        };
 441
 442        let invitee_id = sqlx::query_scalar(
 443            "
 444                INSERT INTO users
 445                    (github_login, email_address, admin, inviter_id)
 446                VALUES
 447                    ($1, $2, 'f', $3)
 448                RETURNING id
 449            ",
 450        )
 451        .bind(login)
 452        .bind(email_address)
 453        .bind(inviter_id)
 454        .fetch_one(&mut tx)
 455        .await
 456        .map(UserId)?;
 457
 458        sqlx::query(
 459            "
 460                INSERT INTO contacts
 461                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
 462                VALUES
 463                    ($1, $2, 't', 't', 't')
 464            ",
 465        )
 466        .bind(inviter_id)
 467        .bind(invitee_id)
 468        .execute(&mut tx)
 469        .await?;
 470
 471        tx.commit().await?;
 472        Ok(invitee_id)
 473    }
 474
 475    // projects
 476
 477    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
 478        Ok(sqlx::query_scalar(
 479            "
 480            INSERT INTO projects(host_user_id)
 481            VALUES ($1)
 482            RETURNING id
 483            ",
 484        )
 485        .bind(host_user_id)
 486        .fetch_one(&self.pool)
 487        .await
 488        .map(ProjectId)?)
 489    }
 490
 491    async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
 492        sqlx::query(
 493            "
 494            UPDATE projects
 495            SET unregistered = 't'
 496            WHERE id = $1
 497            ",
 498        )
 499        .bind(project_id)
 500        .execute(&self.pool)
 501        .await?;
 502        Ok(())
 503    }
 504
 505    async fn update_worktree_extensions(
 506        &self,
 507        project_id: ProjectId,
 508        worktree_id: u64,
 509        extensions: HashMap<String, u32>,
 510    ) -> Result<()> {
 511        if extensions.is_empty() {
 512            return Ok(());
 513        }
 514
 515        let mut query = QueryBuilder::new(
 516            "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
 517        );
 518        query.push_values(extensions, |mut query, (extension, count)| {
 519            query
 520                .push_bind(project_id)
 521                .push_bind(worktree_id as i32)
 522                .push_bind(extension)
 523                .push_bind(count as i32);
 524        });
 525        query.push(
 526            "
 527            ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
 528            count = excluded.count
 529            ",
 530        );
 531        query.build().execute(&self.pool).await?;
 532
 533        Ok(())
 534    }
 535
 536    async fn get_project_extensions(
 537        &self,
 538        project_id: ProjectId,
 539    ) -> Result<HashMap<u64, HashMap<String, usize>>> {
 540        #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
 541        struct WorktreeExtension {
 542            worktree_id: i32,
 543            extension: String,
 544            count: i32,
 545        }
 546
 547        let query = "
 548            SELECT worktree_id, extension, count
 549            FROM worktree_extensions
 550            WHERE project_id = $1
 551        ";
 552        let counts = sqlx::query_as::<_, WorktreeExtension>(query)
 553            .bind(&project_id)
 554            .fetch_all(&self.pool)
 555            .await?;
 556
 557        let mut extension_counts = HashMap::default();
 558        for count in counts {
 559            extension_counts
 560                .entry(count.worktree_id as u64)
 561                .or_insert(HashMap::default())
 562                .insert(count.extension, count.count as usize);
 563        }
 564        Ok(extension_counts)
 565    }
 566
 567    async fn record_user_activity(
 568        &self,
 569        time_period: Range<OffsetDateTime>,
 570        projects: &[(UserId, ProjectId)],
 571    ) -> Result<()> {
 572        let query = "
 573            INSERT INTO project_activity_periods
 574            (ended_at, duration_millis, user_id, project_id)
 575            VALUES
 576            ($1, $2, $3, $4);
 577        ";
 578
 579        let mut tx = self.pool.begin().await?;
 580        let duration_millis =
 581            ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
 582        for (user_id, project_id) in projects {
 583            sqlx::query(query)
 584                .bind(time_period.end)
 585                .bind(duration_millis)
 586                .bind(user_id)
 587                .bind(project_id)
 588                .execute(&mut tx)
 589                .await?;
 590        }
 591        tx.commit().await?;
 592
 593        Ok(())
 594    }
 595
 596    async fn get_top_users_activity_summary(
 597        &self,
 598        time_period: Range<OffsetDateTime>,
 599        max_user_count: usize,
 600    ) -> Result<Vec<UserActivitySummary>> {
 601        let query = "
 602            WITH
 603                project_durations AS (
 604                    SELECT user_id, project_id, SUM(duration_millis) AS project_duration
 605                    FROM project_activity_periods
 606                    WHERE $1 < ended_at AND ended_at <= $2
 607                    GROUP BY user_id, project_id
 608                ),
 609                user_durations AS (
 610                    SELECT user_id, SUM(project_duration) as total_duration
 611                    FROM project_durations
 612                    GROUP BY user_id
 613                    ORDER BY total_duration DESC
 614                    LIMIT $3
 615                )
 616            SELECT user_durations.user_id, users.github_login, project_id, project_duration
 617            FROM user_durations, project_durations, users
 618            WHERE
 619                user_durations.user_id = project_durations.user_id AND
 620                user_durations.user_id = users.id
 621            ORDER BY total_duration DESC, user_id ASC
 622        ";
 623
 624        let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64)>(query)
 625            .bind(time_period.start)
 626            .bind(time_period.end)
 627            .bind(max_user_count as i32)
 628            .fetch(&self.pool);
 629
 630        let mut result = Vec::<UserActivitySummary>::new();
 631        while let Some(row) = rows.next().await {
 632            let (user_id, github_login, project_id, duration_millis) = row?;
 633            let project_id = project_id;
 634            let duration = Duration::from_millis(duration_millis as u64);
 635            if let Some(last_summary) = result.last_mut() {
 636                if last_summary.id == user_id {
 637                    last_summary.project_activity.push((project_id, duration));
 638                    continue;
 639                }
 640            }
 641            result.push(UserActivitySummary {
 642                id: user_id,
 643                project_activity: vec![(project_id, duration)],
 644                github_login,
 645            });
 646        }
 647
 648        Ok(result)
 649    }
 650
 651    async fn get_user_activity_timeline(
 652        &self,
 653        time_period: Range<OffsetDateTime>,
 654        user_id: UserId,
 655    ) -> Result<Vec<UserActivityPeriod>> {
 656        const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
 657
 658        let query = "
 659            SELECT
 660                project_activity_periods.ended_at,
 661                project_activity_periods.duration_millis,
 662                project_activity_periods.project_id,
 663                worktree_extensions.extension,
 664                worktree_extensions.count
 665            FROM project_activity_periods
 666            LEFT OUTER JOIN
 667                worktree_extensions
 668            ON
 669                project_activity_periods.project_id = worktree_extensions.project_id
 670            WHERE
 671                project_activity_periods.user_id = $1 AND
 672                $2 < project_activity_periods.ended_at AND
 673                project_activity_periods.ended_at <= $3
 674            ORDER BY project_activity_periods.id ASC
 675        ";
 676
 677        let mut rows = sqlx::query_as::<
 678            _,
 679            (
 680                PrimitiveDateTime,
 681                i32,
 682                ProjectId,
 683                Option<String>,
 684                Option<i32>,
 685            ),
 686        >(query)
 687        .bind(user_id)
 688        .bind(time_period.start)
 689        .bind(time_period.end)
 690        .fetch(&self.pool);
 691
 692        let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
 693        while let Some(row) = rows.next().await {
 694            let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
 695            let ended_at = ended_at.assume_utc();
 696            let duration = Duration::from_millis(duration_millis as u64);
 697            let started_at = ended_at - duration;
 698            let project_time_periods = time_periods.entry(project_id).or_default();
 699
 700            if let Some(prev_duration) = project_time_periods.last_mut() {
 701                if started_at <= prev_duration.end + COALESCE_THRESHOLD
 702                    && ended_at >= prev_duration.start
 703                {
 704                    prev_duration.end = cmp::max(prev_duration.end, ended_at);
 705                } else {
 706                    project_time_periods.push(UserActivityPeriod {
 707                        project_id,
 708                        start: started_at,
 709                        end: ended_at,
 710                        extensions: Default::default(),
 711                    });
 712                }
 713            } else {
 714                project_time_periods.push(UserActivityPeriod {
 715                    project_id,
 716                    start: started_at,
 717                    end: ended_at,
 718                    extensions: Default::default(),
 719                });
 720            }
 721
 722            if let Some((extension, extension_count)) = extension.zip(extension_count) {
 723                project_time_periods
 724                    .last_mut()
 725                    .unwrap()
 726                    .extensions
 727                    .insert(extension, extension_count as usize);
 728            }
 729        }
 730
 731        let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
 732        durations.sort_unstable_by_key(|duration| duration.start);
 733        Ok(durations)
 734    }
 735
 736    // contacts
 737
 738    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 739        let query = "
 740            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
 741            FROM contacts
 742            WHERE user_id_a = $1 OR user_id_b = $1;
 743        ";
 744
 745        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
 746            .bind(user_id)
 747            .fetch(&self.pool);
 748
 749        let mut contacts = vec![Contact::Accepted {
 750            user_id,
 751            should_notify: false,
 752        }];
 753        while let Some(row) = rows.next().await {
 754            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 755
 756            if user_id_a == user_id {
 757                if accepted {
 758                    contacts.push(Contact::Accepted {
 759                        user_id: user_id_b,
 760                        should_notify: should_notify && a_to_b,
 761                    });
 762                } else if a_to_b {
 763                    contacts.push(Contact::Outgoing { user_id: user_id_b })
 764                } else {
 765                    contacts.push(Contact::Incoming {
 766                        user_id: user_id_b,
 767                        should_notify,
 768                    });
 769                }
 770            } else {
 771                if accepted {
 772                    contacts.push(Contact::Accepted {
 773                        user_id: user_id_a,
 774                        should_notify: should_notify && !a_to_b,
 775                    });
 776                } else if a_to_b {
 777                    contacts.push(Contact::Incoming {
 778                        user_id: user_id_a,
 779                        should_notify,
 780                    });
 781                } else {
 782                    contacts.push(Contact::Outgoing { user_id: user_id_a });
 783                }
 784            }
 785        }
 786
 787        contacts.sort_unstable_by_key(|contact| contact.user_id());
 788
 789        Ok(contacts)
 790    }
 791
 792    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 793        let (id_a, id_b) = if user_id_1 < user_id_2 {
 794            (user_id_1, user_id_2)
 795        } else {
 796            (user_id_2, user_id_1)
 797        };
 798
 799        let query = "
 800            SELECT 1 FROM contacts
 801            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
 802            LIMIT 1
 803        ";
 804        Ok(sqlx::query_scalar::<_, i32>(query)
 805            .bind(id_a.0)
 806            .bind(id_b.0)
 807            .fetch_optional(&self.pool)
 808            .await?
 809            .is_some())
 810    }
 811
 812    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 813        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 814            (sender_id, receiver_id, true)
 815        } else {
 816            (receiver_id, sender_id, false)
 817        };
 818        let query = "
 819            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
 820            VALUES ($1, $2, $3, 'f', 't')
 821            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
 822            SET
 823                accepted = 't',
 824                should_notify = 'f'
 825            WHERE
 826                NOT contacts.accepted AND
 827                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
 828                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
 829        ";
 830        let result = sqlx::query(query)
 831            .bind(id_a.0)
 832            .bind(id_b.0)
 833            .bind(a_to_b)
 834            .execute(&self.pool)
 835            .await?;
 836
 837        if result.rows_affected() == 1 {
 838            Ok(())
 839        } else {
 840            Err(anyhow!("contact already requested"))?
 841        }
 842    }
 843
 844    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 845        let (id_a, id_b) = if responder_id < requester_id {
 846            (responder_id, requester_id)
 847        } else {
 848            (requester_id, responder_id)
 849        };
 850        let query = "
 851            DELETE FROM contacts
 852            WHERE user_id_a = $1 AND user_id_b = $2;
 853        ";
 854        let result = sqlx::query(query)
 855            .bind(id_a.0)
 856            .bind(id_b.0)
 857            .execute(&self.pool)
 858            .await?;
 859
 860        if result.rows_affected() == 1 {
 861            Ok(())
 862        } else {
 863            Err(anyhow!("no such contact"))?
 864        }
 865    }
 866
 867    async fn dismiss_contact_notification(
 868        &self,
 869        user_id: UserId,
 870        contact_user_id: UserId,
 871    ) -> Result<()> {
 872        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
 873            (user_id, contact_user_id, true)
 874        } else {
 875            (contact_user_id, user_id, false)
 876        };
 877
 878        let query = "
 879            UPDATE contacts
 880            SET should_notify = 'f'
 881            WHERE
 882                user_id_a = $1 AND user_id_b = $2 AND
 883                (
 884                    (a_to_b = $3 AND accepted) OR
 885                    (a_to_b != $3 AND NOT accepted)
 886                );
 887        ";
 888
 889        let result = sqlx::query(query)
 890            .bind(id_a.0)
 891            .bind(id_b.0)
 892            .bind(a_to_b)
 893            .execute(&self.pool)
 894            .await?;
 895
 896        if result.rows_affected() == 0 {
 897            Err(anyhow!("no such contact request"))?;
 898        }
 899
 900        Ok(())
 901    }
 902
 903    async fn respond_to_contact_request(
 904        &self,
 905        responder_id: UserId,
 906        requester_id: UserId,
 907        accept: bool,
 908    ) -> Result<()> {
 909        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
 910            (responder_id, requester_id, false)
 911        } else {
 912            (requester_id, responder_id, true)
 913        };
 914        let result = if accept {
 915            let query = "
 916                UPDATE contacts
 917                SET accepted = 't', should_notify = 't'
 918                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
 919            ";
 920            sqlx::query(query)
 921                .bind(id_a.0)
 922                .bind(id_b.0)
 923                .bind(a_to_b)
 924                .execute(&self.pool)
 925                .await?
 926        } else {
 927            let query = "
 928                DELETE FROM contacts
 929                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
 930            ";
 931            sqlx::query(query)
 932                .bind(id_a.0)
 933                .bind(id_b.0)
 934                .bind(a_to_b)
 935                .execute(&self.pool)
 936                .await?
 937        };
 938        if result.rows_affected() == 1 {
 939            Ok(())
 940        } else {
 941            Err(anyhow!("no such contact request"))?
 942        }
 943    }
 944
 945    // access tokens
 946
 947    async fn create_access_token_hash(
 948        &self,
 949        user_id: UserId,
 950        access_token_hash: &str,
 951        max_access_token_count: usize,
 952    ) -> Result<()> {
 953        let insert_query = "
 954            INSERT INTO access_tokens (user_id, hash)
 955            VALUES ($1, $2);
 956        ";
 957        let cleanup_query = "
 958            DELETE FROM access_tokens
 959            WHERE id IN (
 960                SELECT id from access_tokens
 961                WHERE user_id = $1
 962                ORDER BY id DESC
 963                OFFSET $3
 964            )
 965        ";
 966
 967        let mut tx = self.pool.begin().await?;
 968        sqlx::query(insert_query)
 969            .bind(user_id.0)
 970            .bind(access_token_hash)
 971            .execute(&mut tx)
 972            .await?;
 973        sqlx::query(cleanup_query)
 974            .bind(user_id.0)
 975            .bind(access_token_hash)
 976            .bind(max_access_token_count as i32)
 977            .execute(&mut tx)
 978            .await?;
 979        Ok(tx.commit().await?)
 980    }
 981
 982    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 983        let query = "
 984            SELECT hash
 985            FROM access_tokens
 986            WHERE user_id = $1
 987            ORDER BY id DESC
 988        ";
 989        Ok(sqlx::query_scalar(query)
 990            .bind(user_id.0)
 991            .fetch_all(&self.pool)
 992            .await?)
 993    }
 994
 995    // orgs
 996
 997    #[allow(unused)] // Help rust-analyzer
 998    #[cfg(any(test, feature = "seed-support"))]
 999    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1000        let query = "
1001            SELECT *
1002            FROM orgs
1003            WHERE slug = $1
1004        ";
1005        Ok(sqlx::query_as(query)
1006            .bind(slug)
1007            .fetch_optional(&self.pool)
1008            .await?)
1009    }
1010
1011    #[cfg(any(test, feature = "seed-support"))]
1012    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1013        let query = "
1014            INSERT INTO orgs (name, slug)
1015            VALUES ($1, $2)
1016            RETURNING id
1017        ";
1018        Ok(sqlx::query_scalar(query)
1019            .bind(name)
1020            .bind(slug)
1021            .fetch_one(&self.pool)
1022            .await
1023            .map(OrgId)?)
1024    }
1025
1026    #[cfg(any(test, feature = "seed-support"))]
1027    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1028        let query = "
1029            INSERT INTO org_memberships (org_id, user_id, admin)
1030            VALUES ($1, $2, $3)
1031            ON CONFLICT DO NOTHING
1032        ";
1033        Ok(sqlx::query(query)
1034            .bind(org_id.0)
1035            .bind(user_id.0)
1036            .bind(is_admin)
1037            .execute(&self.pool)
1038            .await
1039            .map(drop)?)
1040    }
1041
1042    // channels
1043
1044    #[cfg(any(test, feature = "seed-support"))]
1045    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1046        let query = "
1047            INSERT INTO channels (owner_id, owner_is_user, name)
1048            VALUES ($1, false, $2)
1049            RETURNING id
1050        ";
1051        Ok(sqlx::query_scalar(query)
1052            .bind(org_id.0)
1053            .bind(name)
1054            .fetch_one(&self.pool)
1055            .await
1056            .map(ChannelId)?)
1057    }
1058
1059    #[allow(unused)] // Help rust-analyzer
1060    #[cfg(any(test, feature = "seed-support"))]
1061    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1062        let query = "
1063            SELECT *
1064            FROM channels
1065            WHERE
1066                channels.owner_is_user = false AND
1067                channels.owner_id = $1
1068        ";
1069        Ok(sqlx::query_as(query)
1070            .bind(org_id.0)
1071            .fetch_all(&self.pool)
1072            .await?)
1073    }
1074
1075    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1076        let query = "
1077            SELECT
1078                channels.*
1079            FROM
1080                channel_memberships, channels
1081            WHERE
1082                channel_memberships.user_id = $1 AND
1083                channel_memberships.channel_id = channels.id
1084        ";
1085        Ok(sqlx::query_as(query)
1086            .bind(user_id.0)
1087            .fetch_all(&self.pool)
1088            .await?)
1089    }
1090
1091    async fn can_user_access_channel(
1092        &self,
1093        user_id: UserId,
1094        channel_id: ChannelId,
1095    ) -> Result<bool> {
1096        let query = "
1097            SELECT id
1098            FROM channel_memberships
1099            WHERE user_id = $1 AND channel_id = $2
1100            LIMIT 1
1101        ";
1102        Ok(sqlx::query_scalar::<_, i32>(query)
1103            .bind(user_id.0)
1104            .bind(channel_id.0)
1105            .fetch_optional(&self.pool)
1106            .await
1107            .map(|e| e.is_some())?)
1108    }
1109
1110    #[cfg(any(test, feature = "seed-support"))]
1111    async fn add_channel_member(
1112        &self,
1113        channel_id: ChannelId,
1114        user_id: UserId,
1115        is_admin: bool,
1116    ) -> Result<()> {
1117        let query = "
1118            INSERT INTO channel_memberships (channel_id, user_id, admin)
1119            VALUES ($1, $2, $3)
1120            ON CONFLICT DO NOTHING
1121        ";
1122        Ok(sqlx::query(query)
1123            .bind(channel_id.0)
1124            .bind(user_id.0)
1125            .bind(is_admin)
1126            .execute(&self.pool)
1127            .await
1128            .map(drop)?)
1129    }
1130
1131    // messages
1132
1133    async fn create_channel_message(
1134        &self,
1135        channel_id: ChannelId,
1136        sender_id: UserId,
1137        body: &str,
1138        timestamp: OffsetDateTime,
1139        nonce: u128,
1140    ) -> Result<MessageId> {
1141        let query = "
1142            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1143            VALUES ($1, $2, $3, $4, $5)
1144            ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1145            RETURNING id
1146        ";
1147        Ok(sqlx::query_scalar(query)
1148            .bind(channel_id.0)
1149            .bind(sender_id.0)
1150            .bind(body)
1151            .bind(timestamp)
1152            .bind(Uuid::from_u128(nonce))
1153            .fetch_one(&self.pool)
1154            .await
1155            .map(MessageId)?)
1156    }
1157
1158    async fn get_channel_messages(
1159        &self,
1160        channel_id: ChannelId,
1161        count: usize,
1162        before_id: Option<MessageId>,
1163    ) -> Result<Vec<ChannelMessage>> {
1164        let query = r#"
1165            SELECT * FROM (
1166                SELECT
1167                    id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1168                FROM
1169                    channel_messages
1170                WHERE
1171                    channel_id = $1 AND
1172                    id < $2
1173                ORDER BY id DESC
1174                LIMIT $3
1175            ) as recent_messages
1176            ORDER BY id ASC
1177        "#;
1178        Ok(sqlx::query_as(query)
1179            .bind(channel_id.0)
1180            .bind(before_id.unwrap_or(MessageId::MAX))
1181            .bind(count as i64)
1182            .fetch_all(&self.pool)
1183            .await?)
1184    }
1185
1186    #[cfg(test)]
1187    async fn teardown(&self, url: &str) {
1188        use util::ResultExt;
1189
1190        let query = "
1191            SELECT pg_terminate_backend(pg_stat_activity.pid)
1192            FROM pg_stat_activity
1193            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1194        ";
1195        sqlx::query(query).execute(&self.pool).await.log_err();
1196        self.pool.close().await;
1197        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1198            .await
1199            .log_err();
1200    }
1201
1202    #[cfg(test)]
1203    fn as_fake(&self) -> Option<&tests::FakeDb> {
1204        None
1205    }
1206}
1207
/// Declares a `pub struct $name(pub i32)` newtype for a database id, plus conversion helpers
/// and a `Display` impl.
macro_rules! id_type {
    ($name:ident) => {
        /// Strongly-typed `i32` identifier.
        ///
        /// Both `#[sqlx(transparent)]` and `#[serde(transparent)]` are applied. The value is
        /// therefore encoded exactly like the wrapped integer in queries and in serialized
        /// output.
        #[derive(Clone, Copy, Debug, Default, Hash)]
        #[derive(PartialEq, Eq, PartialOrd, Ord)]
        #[derive(sqlx::Type, Serialize, Deserialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl std::fmt::Display for $name {
            // Delegate to the inner integer so formatter flags (width, fill, ...) still apply.
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }

        // The `allow(unused)` attributes below exist because not every id type generated by
        // this macro uses every helper.
        impl $name {
            /// The largest representable id (`i32::MAX`), usable as an open upper bound.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Builds an id from its protocol (`u64`) form.
            ///
            /// This is an `as` cast that keeps only the low 32 bits. Values above
            /// `i32::MAX` silently wrap instead of failing.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Returns the protocol (`u64`) form of this id.
            ///
            /// This is an `as` cast, so a negative id sign-extends into a very large `u64`.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }
    };
}
1250
id_type!(UserId);
/// A user account decoded from a query row.
///
/// Decoding goes through `FromRow`, so every field name must match a selected column name.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub email_address: Option<String>,
    pub admin: bool,
    /// The user's own invite code, if one has been assigned (see `get_invite_code_for_user`).
    pub invite_code: Option<String>,
    /// Invite count as set by `create_users` / `set_invite_count`. Presumably the number of
    /// invites still available; TODO confirm.
    pub invite_count: i32,
    /// Flag maintained by `set_user_connected_once`.
    pub connected_once: bool,
}
1262
id_type!(ProjectId);
/// A project row decoded via `FromRow`. Field names must match the selected column names.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// The user who hosts the project.
    pub host_user_id: UserId,
    // NOTE(review): presumably set once the host unregisters the project; confirm.
    pub unregistered: bool,
}
1270
/// Aggregated activity for one user: total time spent in each project over a queried range.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// `(project, time spent)` pairs.
    pub project_activity: Vec<(ProjectId, Duration)>,
}
1277
/// One contiguous span of a user's activity in a single project.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    project_id: ProjectId,
    /// Start of the span, serialized as ISO 8601.
    #[serde(with = "time::serde::iso8601")]
    start: OffsetDateTime,
    /// End of the span, serialized as ISO 8601.
    #[serde(with = "time::serde::iso8601")]
    end: OffsetDateTime,
    /// File extension -> count for the project. Presumably the number of files with that
    /// extension; TODO confirm.
    extensions: HashMap<String, usize>,
}
1287
id_type!(OrgId);
/// An organization row decoded via `FromRow`; `slug` is the key used by `find_org_by_slug`.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    pub slug: String,
}
1295
id_type!(ChannelId);
/// A channel row decoded via `FromRow`.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Id of the owner. It is an org id when `owner_is_user` is false, and presumably a user
    /// id when it is true.
    pub owner_id: i32,
    /// Discriminates what kind of entity `owner_id` refers to.
    pub owner_is_user: bool,
}
1304
id_type!(MessageId);
/// A message posted to a channel, decoded via `FromRow`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time. `get_channel_messages` selects it `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied deduplication nonce (a `u128` stored as a UUID).
    pub nonce: Uuid,
}
1315
/// A contact relationship, seen from one user's perspective.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// Both sides have accepted.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A request this user sent that is still pending.
    Outgoing {
        user_id: UserId,
    },
    /// A request this user received that is still pending.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1330
1331impl Contact {
1332    pub fn user_id(&self) -> UserId {
1333        match self {
1334            Contact::Accepted { user_id, .. } => *user_id,
1335            Contact::Outgoing { user_id } => *user_id,
1336            Contact::Incoming { user_id, .. } => *user_id,
1337        }
1338    }
1339}
1340
/// A pending request received from `requester_id`.
///
/// The derived `Ord` compares `requester_id` first and then `should_notify`, so field order
/// matters.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1346
/// Builds a SQL `LIKE` pattern that matches any text containing the alphanumeric characters
/// of `string`, in order, with anything in between (`"ab"` -> `"%a%b%"`).
///
/// Non-alphanumeric characters are dropped. This also keeps `LIKE` metacharacters such as
/// `%`, `_` and `\` out of the pattern. An input with no alphanumerics yields `"%"`.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    pattern.extend(
        string
            .chars()
            .filter(|ch| ch.is_alphanumeric())
            .flat_map(|ch| ['%', ch]),
    );
    pattern.push('%');
    pattern
}
1358
1359#[cfg(test)]
1360pub mod tests {
1361    use super::*;
1362    use anyhow::anyhow;
1363    use collections::BTreeMap;
1364    use gpui::executor::Background;
1365    use lazy_static::lazy_static;
1366    use parking_lot::Mutex;
1367    use rand::prelude::*;
1368    use sqlx::{
1369        migrate::{MigrateDatabase, Migrator},
1370        Postgres,
1371    };
1372    use std::{path::Path, sync::Arc};
1373    use util::post_inc;
1374
1375    #[tokio::test(flavor = "multi_thread")]
1376    async fn test_get_users_by_ids() {
1377        for test_db in [
1378            TestDb::postgres().await,
1379            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1380        ] {
1381            let db = test_db.db();
1382
1383            let user = db.create_user("user", None, false).await.unwrap();
1384            let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1385            let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1386            let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1387
1388            assert_eq!(
1389                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1390                    .await
1391                    .unwrap(),
1392                vec![
1393                    User {
1394                        id: user,
1395                        github_login: "user".to_string(),
1396                        admin: false,
1397                        ..Default::default()
1398                    },
1399                    User {
1400                        id: friend1,
1401                        github_login: "friend-1".to_string(),
1402                        admin: false,
1403                        ..Default::default()
1404                    },
1405                    User {
1406                        id: friend2,
1407                        github_login: "friend-2".to_string(),
1408                        admin: false,
1409                        ..Default::default()
1410                    },
1411                    User {
1412                        id: friend3,
1413                        github_login: "friend-3".to_string(),
1414                        admin: false,
1415                        ..Default::default()
1416                    }
1417                ]
1418            );
1419        }
1420    }
1421
1422    #[tokio::test(flavor = "multi_thread")]
1423    async fn test_create_users() {
1424        let db = TestDb::postgres().await;
1425        let db = db.db();
1426
1427        // Create the first batch of users, ensuring invite counts are assigned
1428        // correctly and the respective invite codes are unique.
1429        let user_ids_batch_1 = db
1430            .create_users(vec![
1431                ("user1".to_string(), "hi@user1.com".to_string(), 5),
1432                ("user2".to_string(), "hi@user2.com".to_string(), 4),
1433                ("user3".to_string(), "hi@user3.com".to_string(), 3),
1434            ])
1435            .await
1436            .unwrap();
1437        assert_eq!(user_ids_batch_1.len(), 3);
1438
1439        let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
1440        assert_eq!(users.len(), 3);
1441        assert_eq!(users[0].github_login, "user1");
1442        assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
1443        assert_eq!(users[0].invite_count, 5);
1444        assert_eq!(users[1].github_login, "user2");
1445        assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
1446        assert_eq!(users[1].invite_count, 4);
1447        assert_eq!(users[2].github_login, "user3");
1448        assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
1449        assert_eq!(users[2].invite_count, 3);
1450
1451        let invite_code_1 = users[0].invite_code.clone().unwrap();
1452        let invite_code_2 = users[1].invite_code.clone().unwrap();
1453        let invite_code_3 = users[2].invite_code.clone().unwrap();
1454        assert_ne!(invite_code_1, invite_code_2);
1455        assert_ne!(invite_code_1, invite_code_3);
1456        assert_ne!(invite_code_2, invite_code_3);
1457
1458        // Create the second batch of users and include a user that is already in the database, ensuring
1459        // the invite count for the existing user is updated without changing their invite code.
1460        let user_ids_batch_2 = db
1461            .create_users(vec![
1462                ("user2".to_string(), "hi@user2.com".to_string(), 10),
1463                ("user4".to_string(), "hi@user4.com".to_string(), 2),
1464            ])
1465            .await
1466            .unwrap();
1467        assert_eq!(user_ids_batch_2.len(), 2);
1468        assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);
1469
1470        let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
1471        assert_eq!(users.len(), 2);
1472        assert_eq!(users[0].github_login, "user2");
1473        assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
1474        assert_eq!(users[0].invite_count, 10);
1475        assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
1476        assert_eq!(users[1].github_login, "user4");
1477        assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
1478        assert_eq!(users[1].invite_count, 2);
1479
1480        let invite_code_4 = users[1].invite_code.clone().unwrap();
1481        assert_ne!(invite_code_4, invite_code_1);
1482        assert_ne!(invite_code_4, invite_code_2);
1483        assert_ne!(invite_code_4, invite_code_3);
1484    }
1485
1486    #[tokio::test(flavor = "multi_thread")]
1487    async fn test_worktree_extensions() {
1488        let test_db = TestDb::postgres().await;
1489        let db = test_db.db();
1490
1491        let user = db.create_user("user_1", None, false).await.unwrap();
1492        let project = db.register_project(user).await.unwrap();
1493
1494        db.update_worktree_extensions(project, 100, Default::default())
1495            .await
1496            .unwrap();
1497        db.update_worktree_extensions(
1498            project,
1499            100,
1500            [("rs".to_string(), 5), ("md".to_string(), 3)]
1501                .into_iter()
1502                .collect(),
1503        )
1504        .await
1505        .unwrap();
1506        db.update_worktree_extensions(
1507            project,
1508            100,
1509            [("rs".to_string(), 6), ("md".to_string(), 5)]
1510                .into_iter()
1511                .collect(),
1512        )
1513        .await
1514        .unwrap();
1515        db.update_worktree_extensions(
1516            project,
1517            101,
1518            [("ts".to_string(), 2), ("md".to_string(), 1)]
1519                .into_iter()
1520                .collect(),
1521        )
1522        .await
1523        .unwrap();
1524
1525        assert_eq!(
1526            db.get_project_extensions(project).await.unwrap(),
1527            [
1528                (
1529                    100,
1530                    [("rs".into(), 6), ("md".into(), 5),]
1531                        .into_iter()
1532                        .collect::<HashMap<_, _>>()
1533                ),
1534                (
1535                    101,
1536                    [("ts".into(), 2), ("md".into(), 1),]
1537                        .into_iter()
1538                        .collect::<HashMap<_, _>>()
1539                )
1540            ]
1541            .into_iter()
1542            .collect()
1543        );
1544    }
1545
1546    #[tokio::test(flavor = "multi_thread")]
1547    async fn test_project_activity() {
1548        let test_db = TestDb::postgres().await;
1549        let db = test_db.db();
1550
1551        let user_1 = db.create_user("user_1", None, false).await.unwrap();
1552        let user_2 = db.create_user("user_2", None, false).await.unwrap();
1553        let user_3 = db.create_user("user_3", None, false).await.unwrap();
1554        let project_1 = db.register_project(user_1).await.unwrap();
1555        db.update_worktree_extensions(
1556            project_1,
1557            1,
1558            HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]),
1559        )
1560        .await
1561        .unwrap();
1562        let project_2 = db.register_project(user_2).await.unwrap();
1563        let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);
1564
1565        // User 2 opens a project
1566        let t1 = t0 + Duration::from_secs(10);
1567        db.record_user_activity(t0..t1, &[(user_2, project_2)])
1568            .await
1569            .unwrap();
1570
1571        let t2 = t1 + Duration::from_secs(10);
1572        db.record_user_activity(t1..t2, &[(user_2, project_2)])
1573            .await
1574            .unwrap();
1575
1576        // User 1 joins the project
1577        let t3 = t2 + Duration::from_secs(10);
1578        db.record_user_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
1579            .await
1580            .unwrap();
1581
1582        // User 1 opens another project
1583        let t4 = t3 + Duration::from_secs(10);
1584        db.record_user_activity(
1585            t3..t4,
1586            &[
1587                (user_2, project_2),
1588                (user_1, project_2),
1589                (user_1, project_1),
1590            ],
1591        )
1592        .await
1593        .unwrap();
1594
1595        // User 3 joins that project
1596        let t5 = t4 + Duration::from_secs(10);
1597        db.record_user_activity(
1598            t4..t5,
1599            &[
1600                (user_2, project_2),
1601                (user_1, project_2),
1602                (user_1, project_1),
1603                (user_3, project_1),
1604            ],
1605        )
1606        .await
1607        .unwrap();
1608
1609        // User 2 leaves
1610        let t6 = t5 + Duration::from_secs(5);
1611        db.record_user_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
1612            .await
1613            .unwrap();
1614
1615        let t7 = t6 + Duration::from_secs(60);
1616        let t8 = t7 + Duration::from_secs(10);
1617        db.record_user_activity(t7..t8, &[(user_1, project_1)])
1618            .await
1619            .unwrap();
1620
1621        assert_eq!(
1622            db.get_top_users_activity_summary(t0..t6, 10).await.unwrap(),
1623            &[
1624                UserActivitySummary {
1625                    id: user_1,
1626                    github_login: "user_1".to_string(),
1627                    project_activity: vec![
1628                        (project_1, Duration::from_secs(25)),
1629                        (project_2, Duration::from_secs(30)),
1630                    ]
1631                },
1632                UserActivitySummary {
1633                    id: user_2,
1634                    github_login: "user_2".to_string(),
1635                    project_activity: vec![(project_2, Duration::from_secs(50))]
1636                },
1637                UserActivitySummary {
1638                    id: user_3,
1639                    github_login: "user_3".to_string(),
1640                    project_activity: vec![(project_1, Duration::from_secs(15))]
1641                },
1642            ]
1643        );
1644        assert_eq!(
1645            db.get_user_activity_timeline(t3..t6, user_1).await.unwrap(),
1646            &[
1647                UserActivityPeriod {
1648                    project_id: project_1,
1649                    start: t3,
1650                    end: t6,
1651                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1652                },
1653                UserActivityPeriod {
1654                    project_id: project_2,
1655                    start: t3,
1656                    end: t5,
1657                    extensions: Default::default(),
1658                },
1659            ]
1660        );
1661        assert_eq!(
1662            db.get_user_activity_timeline(t0..t8, user_1).await.unwrap(),
1663            &[
1664                UserActivityPeriod {
1665                    project_id: project_2,
1666                    start: t2,
1667                    end: t5,
1668                    extensions: Default::default(),
1669                },
1670                UserActivityPeriod {
1671                    project_id: project_1,
1672                    start: t3,
1673                    end: t6,
1674                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1675                },
1676                UserActivityPeriod {
1677                    project_id: project_1,
1678                    start: t7,
1679                    end: t8,
1680                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1681                },
1682            ]
1683        );
1684    }
1685
1686    #[tokio::test(flavor = "multi_thread")]
1687    async fn test_recent_channel_messages() {
1688        for test_db in [
1689            TestDb::postgres().await,
1690            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1691        ] {
1692            let db = test_db.db();
1693            let user = db.create_user("user", None, false).await.unwrap();
1694            let org = db.create_org("org", "org").await.unwrap();
1695            let channel = db.create_org_channel(org, "channel").await.unwrap();
1696            for i in 0..10 {
1697                db.create_channel_message(
1698                    channel,
1699                    user,
1700                    &i.to_string(),
1701                    OffsetDateTime::now_utc(),
1702                    i,
1703                )
1704                .await
1705                .unwrap();
1706            }
1707
1708            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1709            assert_eq!(
1710                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1711                ["5", "6", "7", "8", "9"]
1712            );
1713
1714            let prev_messages = db
1715                .get_channel_messages(channel, 4, Some(messages[0].id))
1716                .await
1717                .unwrap();
1718            assert_eq!(
1719                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1720                ["1", "2", "3", "4"]
1721            );
1722        }
1723    }
1724
1725    #[tokio::test(flavor = "multi_thread")]
1726    async fn test_channel_message_nonces() {
1727        for test_db in [
1728            TestDb::postgres().await,
1729            TestDb::fake(Arc::new(gpui::executor::Background::new())),
1730        ] {
1731            let db = test_db.db();
1732            let user = db.create_user("user", None, false).await.unwrap();
1733            let org = db.create_org("org", "org").await.unwrap();
1734            let channel = db.create_org_channel(org, "channel").await.unwrap();
1735
1736            let msg1_id = db
1737                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1738                .await
1739                .unwrap();
1740            let msg2_id = db
1741                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1742                .await
1743                .unwrap();
1744            let msg3_id = db
1745                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1746                .await
1747                .unwrap();
1748            let msg4_id = db
1749                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1750                .await
1751                .unwrap();
1752
1753            assert_ne!(msg1_id, msg2_id);
1754            assert_eq!(msg1_id, msg3_id);
1755            assert_eq!(msg2_id, msg4_id);
1756        }
1757    }
1758
1759    #[tokio::test(flavor = "multi_thread")]
1760    async fn test_create_access_tokens() {
1761        let test_db = TestDb::postgres().await;
1762        let db = test_db.db();
1763        let user = db.create_user("the-user", None, false).await.unwrap();
1764
1765        db.create_access_token_hash(user, "h1", 3).await.unwrap();
1766        db.create_access_token_hash(user, "h2", 3).await.unwrap();
1767        assert_eq!(
1768            db.get_access_token_hashes(user).await.unwrap(),
1769            &["h2".to_string(), "h1".to_string()]
1770        );
1771
1772        db.create_access_token_hash(user, "h3", 3).await.unwrap();
1773        assert_eq!(
1774            db.get_access_token_hashes(user).await.unwrap(),
1775            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1776        );
1777
1778        db.create_access_token_hash(user, "h4", 3).await.unwrap();
1779        assert_eq!(
1780            db.get_access_token_hashes(user).await.unwrap(),
1781            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1782        );
1783
1784        db.create_access_token_hash(user, "h5", 3).await.unwrap();
1785        assert_eq!(
1786            db.get_access_token_hashes(user).await.unwrap(),
1787            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1788        );
1789    }
1790
1791    #[test]
1792    fn test_fuzzy_like_string() {
1793        assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1794        assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1795        assert_eq!(fuzzy_like_string(" z  "), "%z%");
1796    }
1797
1798    #[tokio::test(flavor = "multi_thread")]
1799    async fn test_fuzzy_search_users() {
1800        let test_db = TestDb::postgres().await;
1801        let db = test_db.db();
1802        for github_login in [
1803            "California",
1804            "colorado",
1805            "oregon",
1806            "washington",
1807            "florida",
1808            "delaware",
1809            "rhode-island",
1810        ] {
1811            db.create_user(github_login, None, false).await.unwrap();
1812        }
1813
1814        assert_eq!(
1815            fuzzy_search_user_names(db, "clr").await,
1816            &["colorado", "California"]
1817        );
1818        assert_eq!(
1819            fuzzy_search_user_names(db, "ro").await,
1820            &["rhode-island", "colorado", "oregon"],
1821        );
1822
1823        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1824            db.fuzzy_search_users(query, 10)
1825                .await
1826                .unwrap()
1827                .into_iter()
1828                .map(|user| user.github_login)
1829                .collect::<Vec<_>>()
1830        }
1831    }
1832
    /// Walks the contact-request lifecycle (request, dismiss, accept, crossing requests,
    /// decline), running the same assertions against both the Postgres backend and `FakeDb`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", None, false).await.unwrap();
            let user_2 = db.create_user("user2", None, false).await.unwrap();
            let user_3 = db.create_user("user3", None, false).await.unwrap();

            // User starts with no contacts other than themselves: every user is listed as an
            // accepted contact of their own.
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                vec![Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );

            // User requests a contact. Both users see the pending request.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert!(!db.has_contact(user_1, user_2).await.unwrap());
            assert!(!db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Outgoing { user_id: user_2 }
                ],
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: true
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again. (The requester, user 1, has no notification to
            // dismiss yet, so their attempt fails.)
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact, and the requester
            // (user 1) is notified of the acceptance.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true
                    }
                ],
            );
            assert!(db.has_contact(user_1, user_2).await.unwrap());
            assert!(db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users cannot re-request existing contacts.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users can't dismiss notifications of them accepting other users' requests.
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap_err();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true,
                    },
                ]
            );

            // Users can dismiss notifications of other users accepting their requests.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users send each other crossing contact requests (one after the other); the second
            // request immediately accepts the first, leaving no pending notification on either side.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    },
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert!(!db.has_contact(user_2, user_3).await.unwrap());
            assert!(!db.has_contact(user_3, user_2).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    }
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );
        }
    }
2052
    /// Exercises the invite-code lifecycle against Postgres: assigning a code, redeeming it
    /// until its count is exhausted, replenishing the count, and rejecting existing users.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", None, false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

        // Setting invite count to 0 when no code is assigned does not assign a new code
        db.set_invite_count(user1, 0).await.unwrap();
        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        // Each redemption decrements the remaining count. The inviter (user 1) sees the new
        // contact with `should_notify: true`; the invitee does not.
        let user2 = db
            .redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db
            .redeem_invite_code(&invite_code, "user-3", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db
            .redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes, and the failed attempt does not
        // consume a use of the code.
        db.redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap_err();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
    }
2205
    /// A database backend owned by a single test, torn down when dropped.
    ///
    /// Both constructors set `db` to `Some`; the `Drop` impl takes it to run teardown.
    pub struct TestDb {
        /// The backend under test: a `PostgresDb` or a `FakeDb`.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the Postgres database; empty for the fake backend.
        pub url: String,
    }
2210
2211    impl TestDb {
2212        pub async fn postgres() -> Self {
2213            lazy_static! {
2214                static ref LOCK: Mutex<()> = Mutex::new(());
2215            }
2216
2217            let _guard = LOCK.lock();
2218            let mut rng = StdRng::from_entropy();
2219            let name = format!("zed-test-{}", rng.gen::<u128>());
2220            let url = format!("postgres://postgres@localhost/{}", name);
2221            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2222            Postgres::create_database(&url)
2223                .await
2224                .expect("failed to create test db");
2225            let db = PostgresDb::new(&url, 5).await.unwrap();
2226            let migrator = Migrator::new(migrations_path).await.unwrap();
2227            migrator.run(&db.pool).await.unwrap();
2228            Self {
2229                db: Some(Arc::new(db)),
2230                url,
2231            }
2232        }
2233
2234        pub fn fake(background: Arc<Background>) -> Self {
2235            Self {
2236                db: Some(Arc::new(FakeDb::new(background))),
2237                url: Default::default(),
2238            }
2239        }
2240
2241        pub fn db(&self) -> &Arc<dyn Db> {
2242            self.db.as_ref().unwrap()
2243        }
2244    }
2245
2246    impl Drop for TestDb {
2247        fn drop(&mut self) {
2248            if let Some(db) = self.db.take() {
2249                futures::executor::block_on(db.teardown(&self.url));
2250            }
2251        }
2252    }
2253
    /// In-memory implementation of [`Db`] used by tests.
    ///
    /// Each table lives behind its own `Mutex`; the `next_*` counters supply fresh ids.
    pub struct FakeDb {
        /// Used by the `Db` methods to simulate random delays.
        background: Arc<Background>,
        /// Users keyed by id.
        pub users: Mutex<BTreeMap<UserId, User>>,
        /// Registered projects keyed by id.
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        /// Extension counts keyed by `(project id, worktree id, extension)`.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        // NOTE(review): the meaning of the `bool` value is not visible in this chunk — confirm.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        // NOTE(review): the meaning of the `bool` value is not visible in this chunk — confirm.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        /// Contact relationships in insertion order.
        pub contacts: Mutex<Vec<FakeContact>>,
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
2271
    /// One directed contact relationship stored by [`FakeDb`].
    #[derive(Debug)]
    pub struct FakeContact {
        /// The user who sent the contact request.
        pub requester_id: UserId,
        /// The user who received the contact request.
        pub responder_id: UserId,
        /// Whether the request has been accepted.
        pub accepted: bool,
        /// Whether a notification is pending: for the responder while the request is still
        /// pending, or for the requester once it has been accepted.
        pub should_notify: bool,
    }
2279
2280    impl FakeDb {
2281        pub fn new(background: Arc<Background>) -> Self {
2282            Self {
2283                background,
2284                users: Default::default(),
2285                next_user_id: Mutex::new(0),
2286                projects: Default::default(),
2287                worktree_extensions: Default::default(),
2288                next_project_id: Mutex::new(1),
2289                orgs: Default::default(),
2290                next_org_id: Mutex::new(1),
2291                org_memberships: Default::default(),
2292                channels: Default::default(),
2293                next_channel_id: Mutex::new(1),
2294                channel_memberships: Default::default(),
2295                channel_messages: Default::default(),
2296                next_channel_message_id: Mutex::new(1),
2297                contacts: Default::default(),
2298            }
2299        }
2300    }
2301
2302    #[async_trait]
2303    impl Db for FakeDb {
2304        async fn create_user(
2305            &self,
2306            github_login: &str,
2307            email_address: Option<&str>,
2308            admin: bool,
2309        ) -> Result<UserId> {
2310            self.background.simulate_random_delay().await;
2311
2312            let mut users = self.users.lock();
2313            if let Some(user) = users
2314                .values()
2315                .find(|user| user.github_login == github_login)
2316            {
2317                Ok(user.id)
2318            } else {
2319                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2320                users.insert(
2321                    user_id,
2322                    User {
2323                        id: user_id,
2324                        github_login: github_login.to_string(),
2325                        email_address: email_address.map(str::to_string),
2326                        admin,
2327                        invite_code: None,
2328                        invite_count: 0,
2329                        connected_once: false,
2330                    },
2331                );
2332                Ok(user_id)
2333            }
2334        }
2335
        /// Not implemented by the fake backend; `unimplemented!()` panics if this is called.
        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2339
        /// Not implemented by the fake backend; `unimplemented!()` panics if this is called.
        async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
            unimplemented!()
        }
2343
        /// Not implemented by the fake backend; `unimplemented!()` panics if this is called.
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2347
2348        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2349            self.background.simulate_random_delay().await;
2350            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2351        }
2352
2353        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2354            self.background.simulate_random_delay().await;
2355            let users = self.users.lock();
2356            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2357        }
2358
        /// Not implemented by the fake backend; `unimplemented!()` panics if this is called.
        async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
            unimplemented!()
        }
2362
2363        async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2364            self.background.simulate_random_delay().await;
2365            Ok(self
2366                .users
2367                .lock()
2368                .values()
2369                .find(|user| user.github_login == github_login)
2370                .cloned())
2371        }
2372
        /// Not implemented by the fake backend; `unimplemented!()` panics if this is called.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
2376
2377        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2378            self.background.simulate_random_delay().await;
2379            let mut users = self.users.lock();
2380            let mut user = users
2381                .get_mut(&id)
2382                .ok_or_else(|| anyhow!("user not found"))?;
2383            user.connected_once = connected_once;
2384            Ok(())
2385        }
2386
        /// Not implemented by the fake backend; `unimplemented!()` panics if this is called.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }
2390
2391        // invite codes
2392
        /// Not implemented by the fake backend; `unimplemented!()` panics if this is called.
        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }
2396
        /// This fake does not track invite codes, so after the simulated delay it always
        /// reports that the user has none.
        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            self.background.simulate_random_delay().await;
            Ok(None)
        }
2401
        /// Not implemented by the fake backend; `unimplemented!()` panics if this is called.
        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }
2405
        /// Not implemented by the fake backend; `unimplemented!()` panics if this is called.
        async fn redeem_invite_code(
            &self,
            _code: &str,
            _login: &str,
            _email_address: Option<&str>,
        ) -> Result<UserId> {
            unimplemented!()
        }
2414
2415        // projects
2416
2417        async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2418            self.background.simulate_random_delay().await;
2419            if !self.users.lock().contains_key(&host_user_id) {
2420                Err(anyhow!("no such user"))?;
2421            }
2422
2423            let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2424            self.projects.lock().insert(
2425                project_id,
2426                Project {
2427                    id: project_id,
2428                    host_user_id,
2429                    unregistered: false,
2430                },
2431            );
2432            Ok(project_id)
2433        }
2434
2435        async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2436            self.background.simulate_random_delay().await;
2437            self.projects
2438                .lock()
2439                .get_mut(&project_id)
2440                .ok_or_else(|| anyhow!("no such project"))?
2441                .unregistered = true;
2442            Ok(())
2443        }
2444
2445        async fn update_worktree_extensions(
2446            &self,
2447            project_id: ProjectId,
2448            worktree_id: u64,
2449            extensions: HashMap<String, u32>,
2450        ) -> Result<()> {
2451            self.background.simulate_random_delay().await;
2452            if !self.projects.lock().contains_key(&project_id) {
2453                Err(anyhow!("no such project"))?;
2454            }
2455
2456            for (extension, count) in extensions {
2457                self.worktree_extensions
2458                    .lock()
2459                    .insert((project_id, worktree_id, extension), count);
2460            }
2461
2462            Ok(())
2463        }
2464
        /// Not implemented by the fake backend; `unimplemented!()` panics if this is called.
        async fn get_project_extensions(
            &self,
            _project_id: ProjectId,
        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
            unimplemented!()
        }
2471
        /// Not implemented by the fake backend; `unimplemented!()` panics if this is called.
        async fn record_user_activity(
            &self,
            _time_period: Range<OffsetDateTime>,
            _active_projects: &[(UserId, ProjectId)],
        ) -> Result<()> {
            unimplemented!()
        }
2479
        /// Not implemented by the fake backend; `unimplemented!()` panics if this is called.
        async fn get_top_users_activity_summary(
            &self,
            _time_period: Range<OffsetDateTime>,
            _limit: usize,
        ) -> Result<Vec<UserActivitySummary>> {
            unimplemented!()
        }
2487
        /// Not implemented by the fake backend; `unimplemented!()` panics if this is called.
        async fn get_user_activity_timeline(
            &self,
            _time_period: Range<OffsetDateTime>,
            _user_id: UserId,
        ) -> Result<Vec<UserActivityPeriod>> {
            unimplemented!()
        }
2495
2496        // contacts
2497
2498        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2499            self.background.simulate_random_delay().await;
2500            let mut contacts = vec![Contact::Accepted {
2501                user_id: id,
2502                should_notify: false,
2503            }];
2504
2505            for contact in self.contacts.lock().iter() {
2506                if contact.requester_id == id {
2507                    if contact.accepted {
2508                        contacts.push(Contact::Accepted {
2509                            user_id: contact.responder_id,
2510                            should_notify: contact.should_notify,
2511                        });
2512                    } else {
2513                        contacts.push(Contact::Outgoing {
2514                            user_id: contact.responder_id,
2515                        });
2516                    }
2517                } else if contact.responder_id == id {
2518                    if contact.accepted {
2519                        contacts.push(Contact::Accepted {
2520                            user_id: contact.requester_id,
2521                            should_notify: false,
2522                        });
2523                    } else {
2524                        contacts.push(Contact::Incoming {
2525                            user_id: contact.requester_id,
2526                            should_notify: contact.should_notify,
2527                        });
2528                    }
2529                }
2530            }
2531
2532            contacts.sort_unstable_by_key(|contact| contact.user_id());
2533            Ok(contacts)
2534        }
2535
2536        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2537            self.background.simulate_random_delay().await;
2538            Ok(self.contacts.lock().iter().any(|contact| {
2539                contact.accepted
2540                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2541                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2542            }))
2543        }
2544
2545        async fn send_contact_request(
2546            &self,
2547            requester_id: UserId,
2548            responder_id: UserId,
2549        ) -> Result<()> {
2550            self.background.simulate_random_delay().await;
2551            let mut contacts = self.contacts.lock();
2552            for contact in contacts.iter_mut() {
2553                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2554                    if contact.accepted {
2555                        Err(anyhow!("contact already exists"))?;
2556                    } else {
2557                        Err(anyhow!("contact already requested"))?;
2558                    }
2559                }
2560                if contact.responder_id == requester_id && contact.requester_id == responder_id {
2561                    if contact.accepted {
2562                        Err(anyhow!("contact already exists"))?;
2563                    } else {
2564                        contact.accepted = true;
2565                        contact.should_notify = false;
2566                        return Ok(());
2567                    }
2568                }
2569            }
2570            contacts.push(FakeContact {
2571                requester_id,
2572                responder_id,
2573                accepted: false,
2574                should_notify: true,
2575            });
2576            Ok(())
2577        }
2578
2579        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2580            self.background.simulate_random_delay().await;
2581            self.contacts.lock().retain(|contact| {
2582                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2583            });
2584            Ok(())
2585        }
2586
2587        async fn dismiss_contact_notification(
2588            &self,
2589            user_id: UserId,
2590            contact_user_id: UserId,
2591        ) -> Result<()> {
2592            self.background.simulate_random_delay().await;
2593            let mut contacts = self.contacts.lock();
2594            for contact in contacts.iter_mut() {
2595                if contact.requester_id == contact_user_id
2596                    && contact.responder_id == user_id
2597                    && !contact.accepted
2598                {
2599                    contact.should_notify = false;
2600                    return Ok(());
2601                }
2602                if contact.requester_id == user_id
2603                    && contact.responder_id == contact_user_id
2604                    && contact.accepted
2605                {
2606                    contact.should_notify = false;
2607                    return Ok(());
2608                }
2609            }
2610            Err(anyhow!("no such notification"))?
2611        }
2612
2613        async fn respond_to_contact_request(
2614            &self,
2615            responder_id: UserId,
2616            requester_id: UserId,
2617            accept: bool,
2618        ) -> Result<()> {
2619            self.background.simulate_random_delay().await;
2620            let mut contacts = self.contacts.lock();
2621            for (ix, contact) in contacts.iter_mut().enumerate() {
2622                if contact.requester_id == requester_id && contact.responder_id == responder_id {
2623                    if contact.accepted {
2624                        Err(anyhow!("contact already confirmed"))?;
2625                    }
2626                    if accept {
2627                        contact.accepted = true;
2628                        contact.should_notify = true;
2629                    } else {
2630                        contacts.remove(ix);
2631                    }
2632                    return Ok(());
2633                }
2634            }
2635            Err(anyhow!("no such contact request"))?
2636        }
2637
2638        async fn create_access_token_hash(
2639            &self,
2640            _user_id: UserId,
2641            _access_token_hash: &str,
2642            _max_access_token_count: usize,
2643        ) -> Result<()> {
2644            unimplemented!()
2645        }
2646
2647        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
2648            unimplemented!()
2649        }
2650
2651        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
2652            unimplemented!()
2653        }
2654
2655        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2656            self.background.simulate_random_delay().await;
2657            let mut orgs = self.orgs.lock();
2658            if orgs.values().any(|org| org.slug == slug) {
2659                Err(anyhow!("org already exists"))?
2660            } else {
2661                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2662                orgs.insert(
2663                    org_id,
2664                    Org {
2665                        id: org_id,
2666                        name: name.to_string(),
2667                        slug: slug.to_string(),
2668                    },
2669                );
2670                Ok(org_id)
2671            }
2672        }
2673
2674        async fn add_org_member(
2675            &self,
2676            org_id: OrgId,
2677            user_id: UserId,
2678            is_admin: bool,
2679        ) -> Result<()> {
2680            self.background.simulate_random_delay().await;
2681            if !self.orgs.lock().contains_key(&org_id) {
2682                Err(anyhow!("org does not exist"))?;
2683            }
2684            if !self.users.lock().contains_key(&user_id) {
2685                Err(anyhow!("user does not exist"))?;
2686            }
2687
2688            self.org_memberships
2689                .lock()
2690                .entry((org_id, user_id))
2691                .or_insert(is_admin);
2692            Ok(())
2693        }
2694
2695        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2696            self.background.simulate_random_delay().await;
2697            if !self.orgs.lock().contains_key(&org_id) {
2698                Err(anyhow!("org does not exist"))?;
2699            }
2700
2701            let mut channels = self.channels.lock();
2702            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2703            channels.insert(
2704                channel_id,
2705                Channel {
2706                    id: channel_id,
2707                    name: name.to_string(),
2708                    owner_id: org_id.0,
2709                    owner_is_user: false,
2710                },
2711            );
2712            Ok(channel_id)
2713        }
2714
2715        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2716            self.background.simulate_random_delay().await;
2717            Ok(self
2718                .channels
2719                .lock()
2720                .values()
2721                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2722                .cloned()
2723                .collect())
2724        }
2725
2726        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2727            self.background.simulate_random_delay().await;
2728            let channels = self.channels.lock();
2729            let memberships = self.channel_memberships.lock();
2730            Ok(channels
2731                .values()
2732                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2733                .cloned()
2734                .collect())
2735        }
2736
2737        async fn can_user_access_channel(
2738            &self,
2739            user_id: UserId,
2740            channel_id: ChannelId,
2741        ) -> Result<bool> {
2742            self.background.simulate_random_delay().await;
2743            Ok(self
2744                .channel_memberships
2745                .lock()
2746                .contains_key(&(channel_id, user_id)))
2747        }
2748
2749        async fn add_channel_member(
2750            &self,
2751            channel_id: ChannelId,
2752            user_id: UserId,
2753            is_admin: bool,
2754        ) -> Result<()> {
2755            self.background.simulate_random_delay().await;
2756            if !self.channels.lock().contains_key(&channel_id) {
2757                Err(anyhow!("channel does not exist"))?;
2758            }
2759            if !self.users.lock().contains_key(&user_id) {
2760                Err(anyhow!("user does not exist"))?;
2761            }
2762
2763            self.channel_memberships
2764                .lock()
2765                .entry((channel_id, user_id))
2766                .or_insert(is_admin);
2767            Ok(())
2768        }
2769
2770        async fn create_channel_message(
2771            &self,
2772            channel_id: ChannelId,
2773            sender_id: UserId,
2774            body: &str,
2775            timestamp: OffsetDateTime,
2776            nonce: u128,
2777        ) -> Result<MessageId> {
2778            self.background.simulate_random_delay().await;
2779            if !self.channels.lock().contains_key(&channel_id) {
2780                Err(anyhow!("channel does not exist"))?;
2781            }
2782            if !self.users.lock().contains_key(&sender_id) {
2783                Err(anyhow!("user does not exist"))?;
2784            }
2785
2786            let mut messages = self.channel_messages.lock();
2787            if let Some(message) = messages
2788                .values()
2789                .find(|message| message.nonce.as_u128() == nonce)
2790            {
2791                Ok(message.id)
2792            } else {
2793                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2794                messages.insert(
2795                    message_id,
2796                    ChannelMessage {
2797                        id: message_id,
2798                        channel_id,
2799                        sender_id,
2800                        body: body.to_string(),
2801                        sent_at: timestamp,
2802                        nonce: Uuid::from_u128(nonce),
2803                    },
2804                );
2805                Ok(message_id)
2806            }
2807        }
2808
2809        async fn get_channel_messages(
2810            &self,
2811            channel_id: ChannelId,
2812            count: usize,
2813            before_id: Option<MessageId>,
2814        ) -> Result<Vec<ChannelMessage>> {
2815            self.background.simulate_random_delay().await;
2816            let mut messages = self
2817                .channel_messages
2818                .lock()
2819                .values()
2820                .rev()
2821                .filter(|message| {
2822                    message.channel_id == channel_id
2823                        && message.id < before_id.unwrap_or(MessageId::MAX)
2824                })
2825                .take(count)
2826                .cloned()
2827                .collect::<Vec<_>>();
2828            messages.sort_unstable_by_key(|message| message.id);
2829            Ok(messages)
2830        }
2831
2832        async fn teardown(&self, _: &str) {}
2833
2834        #[cfg(test)]
2835        fn as_fake(&self) -> Option<&FakeDb> {
2836            Some(self)
2837        }
2838    }
2839}