Total WIP - try making Db a generic struct instead of a trait

Max Brunsfeld created

Change summary

crates/collab/src/auth.rs     |    2 
crates/collab/src/db.rs       | 1002 ++++--------------------------------
crates/collab/src/db_tests.rs |    2 
crates/collab/src/main.rs     |    8 
4 files changed, 118 insertions(+), 896 deletions(-)

Detailed changes

crates/collab/src/auth.rs 🔗

@@ -75,7 +75,7 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
 
 const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
 
-pub async fn create_access_token(db: &dyn db::Db, user_id: UserId) -> Result<String> {
+pub async fn create_access_token(db: &db::DefaultDb, user_id: UserId) -> Result<String> {
     let access_token = rpc::auth::random_token();
     let access_token_hash =
         hash_access_token(&access_token).context("failed to hash access token")?;

crates/collab/src/db.rs 🔗

@@ -1,6 +1,5 @@
 use crate::{Error, Result};
-use anyhow::{anyhow, Context};
-use async_trait::async_trait;
+use anyhow::anyhow;
 use axum::http::StatusCode;
 use collections::HashMap;
 use futures::StreamExt;
@@ -8,186 +7,20 @@ use serde::{Deserialize, Serialize};
 use sqlx::{
     migrate::{Migrate as _, Migration, MigrationSource},
     types::Uuid,
-    FromRow, QueryBuilder,
+    Encode, FromRow, QueryBuilder,
 };
 use std::{cmp, ops::Range, path::Path, time::Duration};
 use time::{OffsetDateTime, PrimitiveDateTime};
 
-#[async_trait]
-pub trait Db: Send + Sync {
-    async fn create_user(
-        &self,
-        email_address: &str,
-        admin: bool,
-        params: NewUserParams,
-    ) -> Result<NewUserResult>;
-    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
-    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
-    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
-    async fn get_user_metrics_id(&self, id: UserId) -> Result<String>;
-    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
-    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
-    async fn get_user_by_github_account(
-        &self,
-        github_login: &str,
-        github_user_id: Option<i32>,
-    ) -> Result<Option<User>>;
-    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
-    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
-    async fn destroy_user(&self, id: UserId) -> Result<()>;
-
-    async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()>;
-    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
-    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
-    async fn create_invite_from_code(
-        &self,
-        code: &str,
-        email_address: &str,
-        device_id: Option<&str>,
-    ) -> Result<Invite>;
-
-    async fn create_signup(&self, signup: Signup) -> Result<()>;
-    async fn get_waitlist_summary(&self) -> Result<WaitlistSummary>;
-    async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>>;
-    async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()>;
-    async fn create_user_from_invite(
-        &self,
-        invite: &Invite,
-        user: NewUserParams,
-    ) -> Result<Option<NewUserResult>>;
-
-    /// Registers a new project for the given user.
-    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
-
-    /// Unregisters a project for the given project id.
-    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
-
-    /// Update file counts by extension for the given project and worktree.
-    async fn update_worktree_extensions(
-        &self,
-        project_id: ProjectId,
-        worktree_id: u64,
-        extensions: HashMap<String, u32>,
-    ) -> Result<()>;
-
-    /// Get the file counts on the given project keyed by their worktree and extension.
-    async fn get_project_extensions(
-        &self,
-        project_id: ProjectId,
-    ) -> Result<HashMap<u64, HashMap<String, usize>>>;
-
-    /// Record which users have been active in which projects during
-    /// a given period of time.
-    async fn record_user_activity(
-        &self,
-        time_period: Range<OffsetDateTime>,
-        active_projects: &[(UserId, ProjectId)],
-    ) -> Result<()>;
-
-    /// Get the number of users who have been active in the given
-    /// time period for at least the given time duration.
-    async fn get_active_user_count(
-        &self,
-        time_period: Range<OffsetDateTime>,
-        min_duration: Duration,
-        only_collaborative: bool,
-    ) -> Result<usize>;
-
-    /// Get the users that have been most active during the given time period,
-    /// along with the amount of time they have been active in each project.
-    async fn get_top_users_activity_summary(
-        &self,
-        time_period: Range<OffsetDateTime>,
-        max_user_count: usize,
-    ) -> Result<Vec<UserActivitySummary>>;
-
-    /// Get the project activity for the given user and time period.
-    async fn get_user_activity_timeline(
-        &self,
-        time_period: Range<OffsetDateTime>,
-        user_id: UserId,
-    ) -> Result<Vec<UserActivityPeriod>>;
-
-    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
-    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
-    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
-    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
-    async fn dismiss_contact_notification(
-        &self,
-        responder_id: UserId,
-        requester_id: UserId,
-    ) -> Result<()>;
-    async fn respond_to_contact_request(
-        &self,
-        responder_id: UserId,
-        requester_id: UserId,
-        accept: bool,
-    ) -> Result<()>;
-
-    async fn create_access_token_hash(
-        &self,
-        user_id: UserId,
-        access_token_hash: &str,
-        max_access_token_count: usize,
-    ) -> Result<()>;
-    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
-
-    #[cfg(any(test, feature = "seed-support"))]
-    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
-    #[cfg(any(test, feature = "seed-support"))]
-    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
-    #[cfg(any(test, feature = "seed-support"))]
-    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
-    #[cfg(any(test, feature = "seed-support"))]
-    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
-    #[cfg(any(test, feature = "seed-support"))]
-
-    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
-    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
-    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
-        -> Result<bool>;
-
-    #[cfg(any(test, feature = "seed-support"))]
-    async fn add_channel_member(
-        &self,
-        channel_id: ChannelId,
-        user_id: UserId,
-        is_admin: bool,
-    ) -> Result<()>;
-    async fn create_channel_message(
-        &self,
-        channel_id: ChannelId,
-        sender_id: UserId,
-        body: &str,
-        timestamp: OffsetDateTime,
-        nonce: u128,
-    ) -> Result<MessageId>;
-    async fn get_channel_messages(
-        &self,
-        channel_id: ChannelId,
-        count: usize,
-        before_id: Option<MessageId>,
-    ) -> Result<Vec<ChannelMessage>>;
-
-    #[cfg(test)]
-    async fn teardown(&self, url: &str);
-
-    #[cfg(test)]
-    fn as_fake(&self) -> Option<&FakeDb>;
-}
-
 #[cfg(any(test, debug_assertions))]
 pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> =
     Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
 
-pub const TEST_MIGRATIONS_PATH: Option<&'static str> =
-    Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite"));
-
 #[cfg(not(any(test, debug_assertions)))]
 pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> = None;
 
-pub struct RealDb {
-    pool: sqlx::SqlitePool,
+pub struct Db<D: sqlx::Database> {
+    pool: sqlx::Pool<D>,
 }
 
 macro_rules! test_support {
@@ -204,16 +37,45 @@ macro_rules! test_support {
     }};
 }
 
-impl RealDb {
-    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
-        eprintln!("{url}");
+impl Db<sqlx::Sqlite> {
+    #[cfg(test)]
+    pub async fn sqlite(url: &str) -> Result<Self> {
         let pool = sqlx::sqlite::SqlitePoolOptions::new()
             .max_connections(1)
             .connect(url)
             .await?;
         Ok(Self { pool })
     }
+}
+
+impl Db<sqlx::Postgres> {
+    pub async fn postgres(url: &str, max_connection: u32) -> Result<Self> {
+        let pool = sqlx::postgres::PgPoolOptions::new()
+            .max_connections(1)
+            .connect(url)
+            .await?;
+        Ok(Self { pool })
+    }
+}
 
+impl<D> Db<D>
+where
+    D: sqlx::Database + sqlx::migrate::MigrateDatabase,
+    for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
+    D: for<'r> sqlx::database::HasValueRef<'r>,
+    D: for<'r> sqlx::database::HasArguments<'r>,
+    for<'a> &'a mut D::Connection: sqlx::Executor<'a>,
+    String: sqlx::Type<D>,
+    i32: sqlx::Type<D>,
+    bool: sqlx::Type<D>,
+    str: sqlx::Type<D>,
+    for<'a> str: sqlx::Encode<'a, D>,
+    for<'a> &'a str: sqlx::Encode<'a, D>,
+    for<'a> String: sqlx::Encode<'a, D>,
+    for<'a> i32: sqlx::Encode<'a, D>,
+    for<'a> bool: sqlx::Encode<'a, D>,
+    for<'a> Option<String>: sqlx::Encode<'a, D>,
+{
     pub async fn migrate(
         &self,
         migrations_path: &Path,
@@ -266,13 +128,10 @@ impl RealDb {
         result.push('%');
         result
     }
-}
 
-#[async_trait]
-impl Db for RealDb {
     // users
 
-    async fn create_user(
+    pub async fn create_user(
         &self,
         email_address: &str,
         admin: bool,
@@ -302,7 +161,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
+    pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
         test_support!(self, {
             let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
             Ok(sqlx::query_as(query)
@@ -313,7 +172,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
+    pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
         test_support!(self, {
             let like_string = Self::fuzzy_like_string(name_query);
             let query = "
@@ -332,7 +191,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
+    pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
         test_support!(self, {
             let query = "
                 SELECT users.*
@@ -347,7 +206,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
+    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
         test_support!(self, {
             let query = "
                 SELECT metrics_id::text
@@ -361,7 +220,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
+    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
         test_support!(self, {
             let query = "
                 SELECT users.*
@@ -375,7 +234,10 @@ impl Db for RealDb {
         })
     }
 
-    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
+    pub async fn get_users_with_no_invites(
+        &self,
+        invited_by_another_user: bool,
+    ) -> Result<Vec<User>> {
         test_support!(self, {
             let query = format!(
                 "
@@ -391,7 +253,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn get_user_by_github_account(
+    pub async fn get_user_by_github_account(
         &self,
         github_login: &str,
         github_user_id: Option<i32>,
@@ -443,7 +305,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
+    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
         test_support!(self, {
             let query = "UPDATE users SET admin = $1 WHERE id = $2";
             Ok(sqlx::query(query)
@@ -455,7 +317,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
+    pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
         test_support!(self, {
             let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
             Ok(sqlx::query(query)
@@ -467,7 +329,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn destroy_user(&self, id: UserId) -> Result<()> {
+    pub async fn destroy_user(&self, id: UserId) -> Result<()> {
         test_support!(self, {
             let query = "DELETE FROM access_tokens WHERE user_id = $1;";
             sqlx::query(query)
@@ -486,7 +348,7 @@ impl Db for RealDb {
 
     // signups
 
-    async fn create_signup(&self, signup: Signup) -> Result<()> {
+    pub async fn create_signup(&self, signup: Signup) -> Result<()> {
         test_support!(self, {
             sqlx::query(
                 "
@@ -522,7 +384,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
+    pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
         test_support!(self, {
             Ok(sqlx::query_as(
                 "
@@ -545,7 +407,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
+    pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
         test_support!(self, {
             Ok(sqlx::query_as(
                 "
@@ -564,28 +426,28 @@ impl Db for RealDb {
         })
     }
 
-    async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
+    pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
         test_support!(self, {
-            // sqlx::query(
-            //     "
-            //     UPDATE signups
-            //     SET email_confirmation_sent = TRUE
-            //     WHERE email_address = ANY ($1)
-            //     ",
-            // )
+            sqlx::query(
+                "
+                UPDATE signups
+                SET email_confirmation_sent = TRUE
+                WHERE email_address = ANY ($1)
+                ",
+            )
             // .bind(
             //     &invites
             //         .iter()
             //         .map(|s| s.email_address.as_str())
             //         .collect::<Vec<_>>(),
             // )
-            // .execute(&self.pool)
-            // .await?;
+            .execute(&self.pool)
+            .await?;
             Ok(())
         })
     }
 
-    async fn create_user_from_invite(
+    pub async fn create_user_from_invite(
         &self,
         invite: &Invite,
         user: NewUserParams,
@@ -697,7 +559,7 @@ impl Db for RealDb {
 
     // invite codes
 
-    async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
+    pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
         test_support!(self, {
             let mut tx = self.pool.begin().await?;
             if count > 0 {
@@ -730,7 +592,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
+    pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
         test_support!(self, {
             let result: Option<(String, i32)> = sqlx::query_as(
                 "
@@ -750,7 +612,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
+    pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
         test_support!(self, {
             sqlx::query_as(
                 "
@@ -771,7 +633,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn create_invite_from_code(
+    pub async fn create_invite_from_code(
         &self,
         code: &str,
         email_address: &str,
@@ -860,7 +722,8 @@ impl Db for RealDb {
 
     // projects
 
-    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
+    /// Registers a new project for the given user.
+    pub async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
         test_support!(self, {
             Ok(sqlx::query_scalar(
                 "
@@ -876,7 +739,8 @@ impl Db for RealDb {
         })
     }
 
-    async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
+    /// Unregisters a project for the given project id.
+    pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
         test_support!(self, {
             sqlx::query(
                 "
@@ -892,7 +756,8 @@ impl Db for RealDb {
         })
     }
 
-    async fn update_worktree_extensions(
+    /// Update file counts by extension for the given project and worktree.
+    pub async fn update_worktree_extensions(
         &self,
         project_id: ProjectId,
         worktree_id: u64,
@@ -925,7 +790,8 @@ impl Db for RealDb {
         })
     }
 
-    async fn get_project_extensions(
+    /// Get the file counts on the given project keyed by their worktree and extension.
+    pub async fn get_project_extensions(
         &self,
         project_id: ProjectId,
     ) -> Result<HashMap<u64, HashMap<String, usize>>> {
@@ -958,7 +824,9 @@ impl Db for RealDb {
         })
     }
 
-    async fn record_user_activity(
+    /// Record which users have been active in which projects during
+    /// a given period of time.
+    pub async fn record_user_activity(
         &self,
         time_period: Range<OffsetDateTime>,
         projects: &[(UserId, ProjectId)],
@@ -989,7 +857,9 @@ impl Db for RealDb {
         })
     }
 
-    async fn get_active_user_count(
+    /// Get the number of users who have been active in the given
+    /// time period for at least the given time duration.
+    pub async fn get_active_user_count(
         &self,
         time_period: Range<OffsetDateTime>,
         min_duration: Duration,
@@ -1066,7 +936,9 @@ impl Db for RealDb {
         })
     }
 
-    async fn get_top_users_activity_summary(
+    /// Get the users that have been most active during the given time period,
+    /// along with the amount of time they have been active in each project.
+    pub async fn get_top_users_activity_summary(
         &self,
         time_period: Range<OffsetDateTime>,
         max_user_count: usize,
@@ -1135,7 +1007,8 @@ impl Db for RealDb {
         })
     }
 
-    async fn get_user_activity_timeline(
+    /// Get the project activity for the given user and time period.
+    pub async fn get_user_activity_timeline(
         &self,
         time_period: Range<OffsetDateTime>,
         user_id: UserId,
@@ -1224,7 +1097,7 @@ impl Db for RealDb {
 
     // contacts
 
-    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
+    pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
         test_support!(self, {
             let query = "
                 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
@@ -1275,7 +1148,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
+    pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
         test_support!(self, {
             let (id_a, id_b) = if user_id_1 < user_id_2 {
                 (user_id_1, user_id_2)
@@ -1297,7 +1170,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
+    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
         test_support!(self, {
             let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
                 (sender_id, receiver_id, true)
@@ -1331,7 +1204,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
+    pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
         test_support!(self, {
             let (id_a, id_b) = if responder_id < requester_id {
                 (responder_id, requester_id)
@@ -1356,7 +1229,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn dismiss_contact_notification(
+    pub async fn dismiss_contact_notification(
         &self,
         user_id: UserId,
         contact_user_id: UserId,
@@ -1394,7 +1267,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn respond_to_contact_request(
+    pub async fn respond_to_contact_request(
         &self,
         responder_id: UserId,
         requester_id: UserId,
@@ -1440,7 +1313,7 @@ impl Db for RealDb {
 
     // access tokens
 
-    async fn create_access_token_hash(
+    pub async fn create_access_token_hash(
         &self,
         user_id: UserId,
         access_token_hash: &str,
@@ -1477,7 +1350,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
+    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
         test_support!(self, {
             let query = "
                 SELECT hash
@@ -1496,7 +1369,7 @@ impl Db for RealDb {
 
     #[allow(unused)] // Help rust-analyzer
     #[cfg(any(test, feature = "seed-support"))]
-    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
+    pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
         test_support!(self, {
             let query = "
                 SELECT *
@@ -1511,7 +1384,7 @@ impl Db for RealDb {
     }
 
     #[cfg(any(test, feature = "seed-support"))]
-    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
+    pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
         test_support!(self, {
             let query = "
                 INSERT INTO orgs (name, slug)
@@ -1528,7 +1401,12 @@ impl Db for RealDb {
     }
 
     #[cfg(any(test, feature = "seed-support"))]
-    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
+    pub async fn add_org_member(
+        &self,
+        org_id: OrgId,
+        user_id: UserId,
+        is_admin: bool,
+    ) -> Result<()> {
         test_support!(self, {
             let query = "
                 INSERT INTO org_memberships (org_id, user_id, admin)
@@ -1548,7 +1426,7 @@ impl Db for RealDb {
     // channels
 
     #[cfg(any(test, feature = "seed-support"))]
-    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
+    pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
         test_support!(self, {
             let query = "
                 INSERT INTO channels (owner_id, owner_is_user, name)
@@ -1566,7 +1444,7 @@ impl Db for RealDb {
 
     #[allow(unused)] // Help rust-analyzer
     #[cfg(any(test, feature = "seed-support"))]
-    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
+    pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
         test_support!(self, {
             let query = "
                 SELECT *
@@ -1582,7 +1460,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
+    pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
         test_support!(self, {
             let query = "
                 SELECT
@@ -1600,7 +1478,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn can_user_access_channel(
+    pub async fn can_user_access_channel(
         &self,
         user_id: UserId,
         channel_id: ChannelId,
@@ -1622,7 +1500,7 @@ impl Db for RealDb {
     }
 
     #[cfg(any(test, feature = "seed-support"))]
-    async fn add_channel_member(
+    pub async fn add_channel_member(
         &self,
         channel_id: ChannelId,
         user_id: UserId,
@@ -1646,7 +1524,7 @@ impl Db for RealDb {
 
     // messages
 
-    async fn create_channel_message(
+    pub async fn create_channel_message(
         &self,
         channel_id: ChannelId,
         sender_id: UserId,
@@ -1673,7 +1551,7 @@ impl Db for RealDb {
         })
     }
 
-    async fn get_channel_messages(
+    pub async fn get_channel_messages(
         &self,
         channel_id: ChannelId,
         count: usize,
@@ -1704,9 +1582,7 @@ impl Db for RealDb {
     }
 
     #[cfg(test)]
-    async fn teardown(&self, url: &str) {
-        let start = std::time::Instant::now();
-        eprintln!("tearing down database...");
+    pub async fn teardown(&self, url: &str) {
         test_support!(self, {
             use util::ResultExt;
 
@@ -1720,14 +1596,8 @@ impl Db for RealDb {
             <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
                 .await
                 .log_err();
-            eprintln!("tore down database: {:?}", start.elapsed());
         })
     }
-
-    #[cfg(test)]
-    fn as_fake(&self) -> Option<&FakeDb> {
-        None
-    }
 }
 
 macro_rules! id_type {
@@ -1937,661 +1807,13 @@ pub use test::*;
 #[cfg(test)]
 mod test {
     use super::*;
-    use anyhow::anyhow;
-    use collections::BTreeMap;
     use gpui::executor::Background;
-    use parking_lot::Mutex;
     use rand::prelude::*;
     use sqlx::{migrate::MigrateDatabase, Sqlite};
     use std::sync::Arc;
-    use util::post_inc;
-
-    pub struct FakeDb {
-        background: Arc<Background>,
-        pub users: Mutex<BTreeMap<UserId, User>>,
-        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
-        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
-        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
-        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
-        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
-        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
-        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
-        pub contacts: Mutex<Vec<FakeContact>>,
-        next_channel_message_id: Mutex<i32>,
-        next_user_id: Mutex<i32>,
-        next_org_id: Mutex<i32>,
-        next_channel_id: Mutex<i32>,
-        next_project_id: Mutex<i32>,
-    }
-
-    #[derive(Debug)]
-    pub struct FakeContact {
-        pub requester_id: UserId,
-        pub responder_id: UserId,
-        pub accepted: bool,
-        pub should_notify: bool,
-    }
-
-    impl FakeDb {
-        pub fn new(background: Arc<Background>) -> Self {
-            Self {
-                background,
-                users: Default::default(),
-                next_user_id: Mutex::new(0),
-                projects: Default::default(),
-                worktree_extensions: Default::default(),
-                next_project_id: Mutex::new(1),
-                orgs: Default::default(),
-                next_org_id: Mutex::new(1),
-                org_memberships: Default::default(),
-                channels: Default::default(),
-                next_channel_id: Mutex::new(1),
-                channel_memberships: Default::default(),
-                channel_messages: Default::default(),
-                next_channel_message_id: Mutex::new(1),
-                contacts: Default::default(),
-            }
-        }
-    }
-
-    #[async_trait]
-    impl Db for FakeDb {
-        async fn create_user(
-            &self,
-            email_address: &str,
-            admin: bool,
-            params: NewUserParams,
-        ) -> Result<NewUserResult> {
-            self.background.simulate_random_delay().await;
-
-            let mut users = self.users.lock();
-            let user_id = if let Some(user) = users
-                .values()
-                .find(|user| user.github_login == params.github_login)
-            {
-                user.id
-            } else {
-                let id = post_inc(&mut *self.next_user_id.lock());
-                let user_id = UserId(id);
-                users.insert(
-                    user_id,
-                    User {
-                        id: user_id,
-                        github_login: params.github_login,
-                        github_user_id: Some(params.github_user_id),
-                        email_address: Some(email_address.to_string()),
-                        admin,
-                        invite_code: None,
-                        invite_count: 0,
-                        connected_once: false,
-                    },
-                );
-                user_id
-            };
-            Ok(NewUserResult {
-                user_id,
-                metrics_id: "the-metrics-id".to_string(),
-                inviting_user_id: None,
-                signup_device_id: None,
-            })
-        }
-
-        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
-            unimplemented!()
-        }
-
-        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
-            unimplemented!()
-        }
-
-        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
-            self.background.simulate_random_delay().await;
-            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
-        }
-
-        async fn get_user_metrics_id(&self, _id: UserId) -> Result<String> {
-            Ok("the-metrics-id".to_string())
-        }
-
-        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
-            self.background.simulate_random_delay().await;
-            let users = self.users.lock();
-            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
-        }
-
-        async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
-            unimplemented!()
-        }
-
-        async fn get_user_by_github_account(
-            &self,
-            github_login: &str,
-            github_user_id: Option<i32>,
-        ) -> Result<Option<User>> {
-            self.background.simulate_random_delay().await;
-            if let Some(github_user_id) = github_user_id {
-                for user in self.users.lock().values_mut() {
-                    if user.github_user_id == Some(github_user_id) {
-                        user.github_login = github_login.into();
-                        return Ok(Some(user.clone()));
-                    }
-                    if user.github_login == github_login {
-                        user.github_user_id = Some(github_user_id);
-                        return Ok(Some(user.clone()));
-                    }
-                }
-                Ok(None)
-            } else {
-                Ok(self
-                    .users
-                    .lock()
-                    .values()
-                    .find(|user| user.github_login == github_login)
-                    .cloned())
-            }
-        }
-
-        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
-            unimplemented!()
-        }
-
-        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
-            self.background.simulate_random_delay().await;
-            let mut users = self.users.lock();
-            let mut user = users
-                .get_mut(&id)
-                .ok_or_else(|| anyhow!("user not found"))?;
-            user.connected_once = connected_once;
-            Ok(())
-        }
-
-        async fn destroy_user(&self, _id: UserId) -> Result<()> {
-            unimplemented!()
-        }
-
-        // signups
-
-        async fn create_signup(&self, _signup: Signup) -> Result<()> {
-            unimplemented!()
-        }
-
-        async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
-            unimplemented!()
-        }
-
-        async fn get_unsent_invites(&self, _count: usize) -> Result<Vec<Invite>> {
-            unimplemented!()
-        }
-
-        async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
-            unimplemented!()
-        }
-
-        async fn create_user_from_invite(
-            &self,
-            _invite: &Invite,
-            _user: NewUserParams,
-        ) -> Result<Option<NewUserResult>> {
-            unimplemented!()
-        }
-
-        // invite codes
-
-        async fn set_invite_count_for_user(&self, _id: UserId, _count: u32) -> Result<()> {
-            unimplemented!()
-        }
-
-        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
-            self.background.simulate_random_delay().await;
-            Ok(None)
-        }
-
-        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
-            unimplemented!()
-        }
-
-        async fn create_invite_from_code(
-            &self,
-            _code: &str,
-            _email_address: &str,
-            _device_id: Option<&str>,
-        ) -> Result<Invite> {
-            unimplemented!()
-        }
-
-        // projects
-
-        async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
-            self.background.simulate_random_delay().await;
-            if !self.users.lock().contains_key(&host_user_id) {
-                Err(anyhow!("no such user"))?;
-            }
-
-            let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
-            self.projects.lock().insert(
-                project_id,
-                Project {
-                    id: project_id,
-                    host_user_id,
-                    unregistered: false,
-                },
-            );
-            Ok(project_id)
-        }
-
-        async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
-            self.background.simulate_random_delay().await;
-            self.projects
-                .lock()
-                .get_mut(&project_id)
-                .ok_or_else(|| anyhow!("no such project"))?
-                .unregistered = true;
-            Ok(())
-        }
-
-        async fn update_worktree_extensions(
-            &self,
-            project_id: ProjectId,
-            worktree_id: u64,
-            extensions: HashMap<String, u32>,
-        ) -> Result<()> {
-            self.background.simulate_random_delay().await;
-            if !self.projects.lock().contains_key(&project_id) {
-                Err(anyhow!("no such project"))?;
-            }
-
-            for (extension, count) in extensions {
-                self.worktree_extensions
-                    .lock()
-                    .insert((project_id, worktree_id, extension), count);
-            }
-
-            Ok(())
-        }
-
-        async fn get_project_extensions(
-            &self,
-            _project_id: ProjectId,
-        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
-            unimplemented!()
-        }
-
-        async fn record_user_activity(
-            &self,
-            _time_period: Range<OffsetDateTime>,
-            _active_projects: &[(UserId, ProjectId)],
-        ) -> Result<()> {
-            unimplemented!()
-        }
-
-        async fn get_active_user_count(
-            &self,
-            _time_period: Range<OffsetDateTime>,
-            _min_duration: Duration,
-            _only_collaborative: bool,
-        ) -> Result<usize> {
-            unimplemented!()
-        }
-
-        async fn get_top_users_activity_summary(
-            &self,
-            _time_period: Range<OffsetDateTime>,
-            _limit: usize,
-        ) -> Result<Vec<UserActivitySummary>> {
-            unimplemented!()
-        }
-
-        async fn get_user_activity_timeline(
-            &self,
-            _time_period: Range<OffsetDateTime>,
-            _user_id: UserId,
-        ) -> Result<Vec<UserActivityPeriod>> {
-            unimplemented!()
-        }
-
-        // contacts
-
-        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
-            self.background.simulate_random_delay().await;
-            let mut contacts = Vec::new();
-
-            for contact in self.contacts.lock().iter() {
-                if contact.requester_id == id {
-                    if contact.accepted {
-                        contacts.push(Contact::Accepted {
-                            user_id: contact.responder_id,
-                            should_notify: contact.should_notify,
-                        });
-                    } else {
-                        contacts.push(Contact::Outgoing {
-                            user_id: contact.responder_id,
-                        });
-                    }
-                } else if contact.responder_id == id {
-                    if contact.accepted {
-                        contacts.push(Contact::Accepted {
-                            user_id: contact.requester_id,
-                            should_notify: false,
-                        });
-                    } else {
-                        contacts.push(Contact::Incoming {
-                            user_id: contact.requester_id,
-                            should_notify: contact.should_notify,
-                        });
-                    }
-                }
-            }
-
-            contacts.sort_unstable_by_key(|contact| contact.user_id());
-            Ok(contacts)
-        }
-
-        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
-            self.background.simulate_random_delay().await;
-            Ok(self.contacts.lock().iter().any(|contact| {
-                contact.accepted
-                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
-                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
-            }))
-        }
-
-        async fn send_contact_request(
-            &self,
-            requester_id: UserId,
-            responder_id: UserId,
-        ) -> Result<()> {
-            self.background.simulate_random_delay().await;
-            let mut contacts = self.contacts.lock();
-            for contact in contacts.iter_mut() {
-                if contact.requester_id == requester_id && contact.responder_id == responder_id {
-                    if contact.accepted {
-                        Err(anyhow!("contact already exists"))?;
-                    } else {
-                        Err(anyhow!("contact already requested"))?;
-                    }
-                }
-                if contact.responder_id == requester_id && contact.requester_id == responder_id {
-                    if contact.accepted {
-                        Err(anyhow!("contact already exists"))?;
-                    } else {
-                        contact.accepted = true;
-                        contact.should_notify = false;
-                        return Ok(());
-                    }
-                }
-            }
-            contacts.push(FakeContact {
-                requester_id,
-                responder_id,
-                accepted: false,
-                should_notify: true,
-            });
-            Ok(())
-        }
-
-        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
-            self.background.simulate_random_delay().await;
-            self.contacts.lock().retain(|contact| {
-                !(contact.requester_id == requester_id && contact.responder_id == responder_id)
-            });
-            Ok(())
-        }
-
-        async fn dismiss_contact_notification(
-            &self,
-            user_id: UserId,
-            contact_user_id: UserId,
-        ) -> Result<()> {
-            self.background.simulate_random_delay().await;
-            let mut contacts = self.contacts.lock();
-            for contact in contacts.iter_mut() {
-                if contact.requester_id == contact_user_id
-                    && contact.responder_id == user_id
-                    && !contact.accepted
-                {
-                    contact.should_notify = false;
-                    return Ok(());
-                }
-                if contact.requester_id == user_id
-                    && contact.responder_id == contact_user_id
-                    && contact.accepted
-                {
-                    contact.should_notify = false;
-                    return Ok(());
-                }
-            }
-            Err(anyhow!("no such notification"))?
-        }
-
-        async fn respond_to_contact_request(
-            &self,
-            responder_id: UserId,
-            requester_id: UserId,
-            accept: bool,
-        ) -> Result<()> {
-            self.background.simulate_random_delay().await;
-            let mut contacts = self.contacts.lock();
-            for (ix, contact) in contacts.iter_mut().enumerate() {
-                if contact.requester_id == requester_id && contact.responder_id == responder_id {
-                    if contact.accepted {
-                        Err(anyhow!("contact already confirmed"))?;
-                    }
-                    if accept {
-                        contact.accepted = true;
-                        contact.should_notify = true;
-                    } else {
-                        contacts.remove(ix);
-                    }
-                    return Ok(());
-                }
-            }
-            Err(anyhow!("no such contact request"))?
-        }
-
-        async fn create_access_token_hash(
-            &self,
-            _user_id: UserId,
-            _access_token_hash: &str,
-            _max_access_token_count: usize,
-        ) -> Result<()> {
-            unimplemented!()
-        }
-
-        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
-            unimplemented!()
-        }
-
-        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
-            unimplemented!()
-        }
-
-        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
-            self.background.simulate_random_delay().await;
-            let mut orgs = self.orgs.lock();
-            if orgs.values().any(|org| org.slug == slug) {
-                Err(anyhow!("org already exists"))?
-            } else {
-                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
-                orgs.insert(
-                    org_id,
-                    Org {
-                        id: org_id,
-                        name: name.to_string(),
-                        slug: slug.to_string(),
-                    },
-                );
-                Ok(org_id)
-            }
-        }
-
-        async fn add_org_member(
-            &self,
-            org_id: OrgId,
-            user_id: UserId,
-            is_admin: bool,
-        ) -> Result<()> {
-            self.background.simulate_random_delay().await;
-            if !self.orgs.lock().contains_key(&org_id) {
-                Err(anyhow!("org does not exist"))?;
-            }
-            if !self.users.lock().contains_key(&user_id) {
-                Err(anyhow!("user does not exist"))?;
-            }
-
-            self.org_memberships
-                .lock()
-                .entry((org_id, user_id))
-                .or_insert(is_admin);
-            Ok(())
-        }
-
-        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
-            self.background.simulate_random_delay().await;
-            if !self.orgs.lock().contains_key(&org_id) {
-                Err(anyhow!("org does not exist"))?;
-            }
-
-            let mut channels = self.channels.lock();
-            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
-            channels.insert(
-                channel_id,
-                Channel {
-                    id: channel_id,
-                    name: name.to_string(),
-                    owner_id: org_id.0,
-                    owner_is_user: false,
-                },
-            );
-            Ok(channel_id)
-        }
-
-        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
-            self.background.simulate_random_delay().await;
-            Ok(self
-                .channels
-                .lock()
-                .values()
-                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
-                .cloned()
-                .collect())
-        }
-
-        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
-            self.background.simulate_random_delay().await;
-            let channels = self.channels.lock();
-            let memberships = self.channel_memberships.lock();
-            Ok(channels
-                .values()
-                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
-                .cloned()
-                .collect())
-        }
-
-        async fn can_user_access_channel(
-            &self,
-            user_id: UserId,
-            channel_id: ChannelId,
-        ) -> Result<bool> {
-            self.background.simulate_random_delay().await;
-            Ok(self
-                .channel_memberships
-                .lock()
-                .contains_key(&(channel_id, user_id)))
-        }
-
-        async fn add_channel_member(
-            &self,
-            channel_id: ChannelId,
-            user_id: UserId,
-            is_admin: bool,
-        ) -> Result<()> {
-            self.background.simulate_random_delay().await;
-            if !self.channels.lock().contains_key(&channel_id) {
-                Err(anyhow!("channel does not exist"))?;
-            }
-            if !self.users.lock().contains_key(&user_id) {
-                Err(anyhow!("user does not exist"))?;
-            }
-
-            self.channel_memberships
-                .lock()
-                .entry((channel_id, user_id))
-                .or_insert(is_admin);
-            Ok(())
-        }
-
-        async fn create_channel_message(
-            &self,
-            channel_id: ChannelId,
-            sender_id: UserId,
-            body: &str,
-            timestamp: OffsetDateTime,
-            nonce: u128,
-        ) -> Result<MessageId> {
-            self.background.simulate_random_delay().await;
-            if !self.channels.lock().contains_key(&channel_id) {
-                Err(anyhow!("channel does not exist"))?;
-            }
-            if !self.users.lock().contains_key(&sender_id) {
-                Err(anyhow!("user does not exist"))?;
-            }
-
-            let mut messages = self.channel_messages.lock();
-            if let Some(message) = messages
-                .values()
-                .find(|message| message.nonce.as_u128() == nonce)
-            {
-                Ok(message.id)
-            } else {
-                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
-                messages.insert(
-                    message_id,
-                    ChannelMessage {
-                        id: message_id,
-                        channel_id,
-                        sender_id,
-                        body: body.to_string(),
-                        sent_at: timestamp,
-                        nonce: Uuid::from_u128(nonce),
-                    },
-                );
-                Ok(message_id)
-            }
-        }
-
-        async fn get_channel_messages(
-            &self,
-            channel_id: ChannelId,
-            count: usize,
-            before_id: Option<MessageId>,
-        ) -> Result<Vec<ChannelMessage>> {
-            self.background.simulate_random_delay().await;
-            let mut messages = self
-                .channel_messages
-                .lock()
-                .values()
-                .rev()
-                .filter(|message| {
-                    message.channel_id == channel_id
-                        && message.id < before_id.unwrap_or(MessageId::MAX)
-                })
-                .take(count)
-                .cloned()
-                .collect::<Vec<_>>();
-            messages.sort_unstable_by_key(|message| message.id);
-            Ok(messages)
-        }
-
-        async fn teardown(&self, _: &str) {}
-
-        #[cfg(test)]
-        fn as_fake(&self) -> Option<&FakeDb> {
-            Some(self)
-        }
-    }
 
     pub struct TestDb {
-        pub db: Option<Arc<dyn Db>>,
+        pub db: Option<Arc<Db>>,
         pub url: String,
     }
 

crates/collab/src/db_tests.rs 🔗

@@ -625,7 +625,7 @@ async fn test_fuzzy_search_users() {
         &["rhode-island", "colorado", "oregon"],
     );
 
-    async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
+    async fn fuzzy_search_user_names(db: &Arc<TestDb>, query: &str) -> Vec<String> {
         db.fuzzy_search_users(query, 10)
             .await
             .unwrap()

crates/collab/src/main.rs 🔗

@@ -13,7 +13,7 @@ use crate::rpc::ResultExt as _;
 use anyhow::anyhow;
 use axum::{routing::get, Router};
 use collab::{Error, Result};
-use db::{Db, RealDb};
+use db::DefaultDb as Db;
 use serde::Deserialize;
 use std::{
     env::args,
@@ -49,14 +49,14 @@ pub struct MigrateConfig {
 }
 
 pub struct AppState {
-    db: Arc<dyn Db>,
+    db: Arc<Db>,
     live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
     config: Config,
 }
 
 impl AppState {
     async fn new(config: Config) -> Result<Arc<Self>> {
-        let db = RealDb::new(&config.database_url, 5).await?;
+        let db = Db::new(&config.database_url, 5).await?;
         let live_kit_client = if let Some(((server, key), secret)) = config
             .live_kit_server
             .as_ref()
@@ -96,7 +96,7 @@ async fn main() -> Result<()> {
         }
         Some("migrate") => {
             let config = envy::from_env::<MigrateConfig>().expect("error loading config");
-            let db = RealDb::new(&config.database_url, 5).await?;
+            let db = Db::new(&config.database_url, 5).await?;
 
             let migrations_path = config
                 .migrations_path