@@ -1,6 +1,5 @@
use crate::{Error, Result};
-use anyhow::{anyhow, Context};
-use async_trait::async_trait;
+use anyhow::anyhow;
use axum::http::StatusCode;
use collections::HashMap;
use futures::StreamExt;
@@ -8,186 +7,20 @@ use serde::{Deserialize, Serialize};
use sqlx::{
migrate::{Migrate as _, Migration, MigrationSource},
types::Uuid,
- FromRow, QueryBuilder,
+ Encode, FromRow, QueryBuilder,
};
use std::{cmp, ops::Range, path::Path, time::Duration};
use time::{OffsetDateTime, PrimitiveDateTime};
-#[async_trait]
-pub trait Db: Send + Sync {
- async fn create_user(
- &self,
- email_address: &str,
- admin: bool,
- params: NewUserParams,
- ) -> Result<NewUserResult>;
- async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
- async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
- async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
- async fn get_user_metrics_id(&self, id: UserId) -> Result<String>;
- async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
- async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
- async fn get_user_by_github_account(
- &self,
- github_login: &str,
- github_user_id: Option<i32>,
- ) -> Result<Option<User>>;
- async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
- async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
- async fn destroy_user(&self, id: UserId) -> Result<()>;
-
- async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()>;
- async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
- async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
- async fn create_invite_from_code(
- &self,
- code: &str,
- email_address: &str,
- device_id: Option<&str>,
- ) -> Result<Invite>;
-
- async fn create_signup(&self, signup: Signup) -> Result<()>;
- async fn get_waitlist_summary(&self) -> Result<WaitlistSummary>;
- async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>>;
- async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()>;
- async fn create_user_from_invite(
- &self,
- invite: &Invite,
- user: NewUserParams,
- ) -> Result<Option<NewUserResult>>;
-
- /// Registers a new project for the given user.
- async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
-
- /// Unregisters a project for the given project id.
- async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
-
- /// Update file counts by extension for the given project and worktree.
- async fn update_worktree_extensions(
- &self,
- project_id: ProjectId,
- worktree_id: u64,
- extensions: HashMap<String, u32>,
- ) -> Result<()>;
-
- /// Get the file counts on the given project keyed by their worktree and extension.
- async fn get_project_extensions(
- &self,
- project_id: ProjectId,
- ) -> Result<HashMap<u64, HashMap<String, usize>>>;
-
- /// Record which users have been active in which projects during
- /// a given period of time.
- async fn record_user_activity(
- &self,
- time_period: Range<OffsetDateTime>,
- active_projects: &[(UserId, ProjectId)],
- ) -> Result<()>;
-
- /// Get the number of users who have been active in the given
- /// time period for at least the given time duration.
- async fn get_active_user_count(
- &self,
- time_period: Range<OffsetDateTime>,
- min_duration: Duration,
- only_collaborative: bool,
- ) -> Result<usize>;
-
- /// Get the users that have been most active during the given time period,
- /// along with the amount of time they have been active in each project.
- async fn get_top_users_activity_summary(
- &self,
- time_period: Range<OffsetDateTime>,
- max_user_count: usize,
- ) -> Result<Vec<UserActivitySummary>>;
-
- /// Get the project activity for the given user and time period.
- async fn get_user_activity_timeline(
- &self,
- time_period: Range<OffsetDateTime>,
- user_id: UserId,
- ) -> Result<Vec<UserActivityPeriod>>;
-
- async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
- async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
- async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
- async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
- async fn dismiss_contact_notification(
- &self,
- responder_id: UserId,
- requester_id: UserId,
- ) -> Result<()>;
- async fn respond_to_contact_request(
- &self,
- responder_id: UserId,
- requester_id: UserId,
- accept: bool,
- ) -> Result<()>;
-
- async fn create_access_token_hash(
- &self,
- user_id: UserId,
- access_token_hash: &str,
- max_access_token_count: usize,
- ) -> Result<()>;
- async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
-
- #[cfg(any(test, feature = "seed-support"))]
- async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
- #[cfg(any(test, feature = "seed-support"))]
- async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
- #[cfg(any(test, feature = "seed-support"))]
- async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
- #[cfg(any(test, feature = "seed-support"))]
- async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
- #[cfg(any(test, feature = "seed-support"))]
-
- async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
- async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
- async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
- -> Result<bool>;
-
- #[cfg(any(test, feature = "seed-support"))]
- async fn add_channel_member(
- &self,
- channel_id: ChannelId,
- user_id: UserId,
- is_admin: bool,
- ) -> Result<()>;
- async fn create_channel_message(
- &self,
- channel_id: ChannelId,
- sender_id: UserId,
- body: &str,
- timestamp: OffsetDateTime,
- nonce: u128,
- ) -> Result<MessageId>;
- async fn get_channel_messages(
- &self,
- channel_id: ChannelId,
- count: usize,
- before_id: Option<MessageId>,
- ) -> Result<Vec<ChannelMessage>>;
-
- #[cfg(test)]
- async fn teardown(&self, url: &str);
-
- #[cfg(test)]
- fn as_fake(&self) -> Option<&FakeDb>;
-}
-
#[cfg(any(test, debug_assertions))]
pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> =
Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
-pub const TEST_MIGRATIONS_PATH: Option<&'static str> =
- Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite"));
-
#[cfg(not(any(test, debug_assertions)))]
pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> = None;
-pub struct RealDb {
- pool: sqlx::SqlitePool,
+pub struct Db<D: sqlx::Database> {
+ pool: sqlx::Pool<D>,
}
macro_rules! test_support {
@@ -204,16 +37,45 @@ macro_rules! test_support {
}};
}
-impl RealDb {
- pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
- eprintln!("{url}");
+impl Db<sqlx::Sqlite> {
+ #[cfg(test)]
+ pub async fn sqlite(url: &str) -> Result<Self> {
let pool = sqlx::sqlite::SqlitePoolOptions::new()
.max_connections(1)
.connect(url)
.await?;
Ok(Self { pool })
}
+}
+
+impl Db<sqlx::Postgres> {
+ pub async fn postgres(url: &str, max_connection: u32) -> Result<Self> {
+ let pool = sqlx::postgres::PgPoolOptions::new()
+ .max_connections(1)
+ .connect(url)
+ .await?;
+ Ok(Self { pool })
+ }
+}
+impl<D> Db<D>
+where
+ D: sqlx::Database + sqlx::migrate::MigrateDatabase,
+ for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
+ D: for<'r> sqlx::database::HasValueRef<'r>,
+ D: for<'r> sqlx::database::HasArguments<'r>,
+ for<'a> &'a mut D::Connection: sqlx::Executor<'a>,
+ String: sqlx::Type<D>,
+ i32: sqlx::Type<D>,
+ bool: sqlx::Type<D>,
+ str: sqlx::Type<D>,
+ for<'a> str: sqlx::Encode<'a, D>,
+ for<'a> &'a str: sqlx::Encode<'a, D>,
+ for<'a> String: sqlx::Encode<'a, D>,
+ for<'a> i32: sqlx::Encode<'a, D>,
+ for<'a> bool: sqlx::Encode<'a, D>,
+ for<'a> Option<String>: sqlx::Encode<'a, D>,
+{
pub async fn migrate(
&self,
migrations_path: &Path,
@@ -266,13 +128,10 @@ impl RealDb {
result.push('%');
result
}
-}
-#[async_trait]
-impl Db for RealDb {
// users
- async fn create_user(
+ pub async fn create_user(
&self,
email_address: &str,
admin: bool,
@@ -302,7 +161,7 @@ impl Db for RealDb {
})
}
- async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
+ pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
test_support!(self, {
let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
Ok(sqlx::query_as(query)
@@ -313,7 +172,7 @@ impl Db for RealDb {
})
}
- async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
+ pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
test_support!(self, {
let like_string = Self::fuzzy_like_string(name_query);
let query = "
@@ -332,7 +191,7 @@ impl Db for RealDb {
})
}
- async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
+ pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
test_support!(self, {
let query = "
SELECT users.*
@@ -347,7 +206,7 @@ impl Db for RealDb {
})
}
- async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
+ pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
test_support!(self, {
let query = "
SELECT metrics_id::text
@@ -361,7 +220,7 @@ impl Db for RealDb {
})
}
- async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
+ pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
test_support!(self, {
let query = "
SELECT users.*
@@ -375,7 +234,10 @@ impl Db for RealDb {
})
}
- async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
+ pub async fn get_users_with_no_invites(
+ &self,
+ invited_by_another_user: bool,
+ ) -> Result<Vec<User>> {
test_support!(self, {
let query = format!(
"
@@ -391,7 +253,7 @@ impl Db for RealDb {
})
}
- async fn get_user_by_github_account(
+ pub async fn get_user_by_github_account(
&self,
github_login: &str,
github_user_id: Option<i32>,
@@ -443,7 +305,7 @@ impl Db for RealDb {
})
}
- async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
+ pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
test_support!(self, {
let query = "UPDATE users SET admin = $1 WHERE id = $2";
Ok(sqlx::query(query)
@@ -455,7 +317,7 @@ impl Db for RealDb {
})
}
- async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
+ pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
test_support!(self, {
let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
Ok(sqlx::query(query)
@@ -467,7 +329,7 @@ impl Db for RealDb {
})
}
- async fn destroy_user(&self, id: UserId) -> Result<()> {
+ pub async fn destroy_user(&self, id: UserId) -> Result<()> {
test_support!(self, {
let query = "DELETE FROM access_tokens WHERE user_id = $1;";
sqlx::query(query)
@@ -486,7 +348,7 @@ impl Db for RealDb {
// signups
- async fn create_signup(&self, signup: Signup) -> Result<()> {
+ pub async fn create_signup(&self, signup: Signup) -> Result<()> {
test_support!(self, {
sqlx::query(
"
@@ -522,7 +384,7 @@ impl Db for RealDb {
})
}
- async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
+ pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
test_support!(self, {
Ok(sqlx::query_as(
"
@@ -545,7 +407,7 @@ impl Db for RealDb {
})
}
- async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
+ pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
test_support!(self, {
Ok(sqlx::query_as(
"
@@ -564,28 +426,28 @@ impl Db for RealDb {
})
}
- async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
+ pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
test_support!(self, {
- // sqlx::query(
- // "
- // UPDATE signups
- // SET email_confirmation_sent = TRUE
- // WHERE email_address = ANY ($1)
- // ",
- // )
+ sqlx::query(
+ "
+ UPDATE signups
+ SET email_confirmation_sent = TRUE
+ WHERE email_address = ANY ($1)
+ ",
+ )
// .bind(
// &invites
// .iter()
// .map(|s| s.email_address.as_str())
// .collect::<Vec<_>>(),
// )
- // .execute(&self.pool)
- // .await?;
+ .execute(&self.pool)
+ .await?;
Ok(())
})
}
- async fn create_user_from_invite(
+ pub async fn create_user_from_invite(
&self,
invite: &Invite,
user: NewUserParams,
@@ -697,7 +559,7 @@ impl Db for RealDb {
// invite codes
- async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
+ pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
test_support!(self, {
let mut tx = self.pool.begin().await?;
if count > 0 {
@@ -730,7 +592,7 @@ impl Db for RealDb {
})
}
- async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
+ pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
test_support!(self, {
let result: Option<(String, i32)> = sqlx::query_as(
"
@@ -750,7 +612,7 @@ impl Db for RealDb {
})
}
- async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
+ pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
test_support!(self, {
sqlx::query_as(
"
@@ -771,7 +633,7 @@ impl Db for RealDb {
})
}
- async fn create_invite_from_code(
+ pub async fn create_invite_from_code(
&self,
code: &str,
email_address: &str,
@@ -860,7 +722,8 @@ impl Db for RealDb {
// projects
- async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
+ /// Registers a new project for the given user.
+ pub async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
test_support!(self, {
Ok(sqlx::query_scalar(
"
@@ -876,7 +739,8 @@ impl Db for RealDb {
})
}
- async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
+ /// Unregisters a project for the given project id.
+ pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
test_support!(self, {
sqlx::query(
"
@@ -892,7 +756,8 @@ impl Db for RealDb {
})
}
- async fn update_worktree_extensions(
+ /// Update file counts by extension for the given project and worktree.
+ pub async fn update_worktree_extensions(
&self,
project_id: ProjectId,
worktree_id: u64,
@@ -925,7 +790,8 @@ impl Db for RealDb {
})
}
- async fn get_project_extensions(
+ /// Get the file counts on the given project keyed by their worktree and extension.
+ pub async fn get_project_extensions(
&self,
project_id: ProjectId,
) -> Result<HashMap<u64, HashMap<String, usize>>> {
@@ -958,7 +824,9 @@ impl Db for RealDb {
})
}
- async fn record_user_activity(
+ /// Record which users have been active in which projects during
+ /// a given period of time.
+ pub async fn record_user_activity(
&self,
time_period: Range<OffsetDateTime>,
projects: &[(UserId, ProjectId)],
@@ -989,7 +857,9 @@ impl Db for RealDb {
})
}
- async fn get_active_user_count(
+ /// Get the number of users who have been active in the given
+ /// time period for at least the given time duration.
+ pub async fn get_active_user_count(
&self,
time_period: Range<OffsetDateTime>,
min_duration: Duration,
@@ -1066,7 +936,9 @@ impl Db for RealDb {
})
}
- async fn get_top_users_activity_summary(
+ /// Get the users that have been most active during the given time period,
+ /// along with the amount of time they have been active in each project.
+ pub async fn get_top_users_activity_summary(
&self,
time_period: Range<OffsetDateTime>,
max_user_count: usize,
@@ -1135,7 +1007,8 @@ impl Db for RealDb {
})
}
- async fn get_user_activity_timeline(
+ /// Get the project activity for the given user and time period.
+ pub async fn get_user_activity_timeline(
&self,
time_period: Range<OffsetDateTime>,
user_id: UserId,
@@ -1224,7 +1097,7 @@ impl Db for RealDb {
// contacts
- async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
+ pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
test_support!(self, {
let query = "
SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
@@ -1275,7 +1148,7 @@ impl Db for RealDb {
})
}
- async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
+ pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
test_support!(self, {
let (id_a, id_b) = if user_id_1 < user_id_2 {
(user_id_1, user_id_2)
@@ -1297,7 +1170,7 @@ impl Db for RealDb {
})
}
- async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
+ pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
test_support!(self, {
let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
(sender_id, receiver_id, true)
@@ -1331,7 +1204,7 @@ impl Db for RealDb {
})
}
- async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
+ pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
test_support!(self, {
let (id_a, id_b) = if responder_id < requester_id {
(responder_id, requester_id)
@@ -1356,7 +1229,7 @@ impl Db for RealDb {
})
}
- async fn dismiss_contact_notification(
+ pub async fn dismiss_contact_notification(
&self,
user_id: UserId,
contact_user_id: UserId,
@@ -1394,7 +1267,7 @@ impl Db for RealDb {
})
}
- async fn respond_to_contact_request(
+ pub async fn respond_to_contact_request(
&self,
responder_id: UserId,
requester_id: UserId,
@@ -1440,7 +1313,7 @@ impl Db for RealDb {
// access tokens
- async fn create_access_token_hash(
+ pub async fn create_access_token_hash(
&self,
user_id: UserId,
access_token_hash: &str,
@@ -1477,7 +1350,7 @@ impl Db for RealDb {
})
}
- async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
+ pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
test_support!(self, {
let query = "
SELECT hash
@@ -1496,7 +1369,7 @@ impl Db for RealDb {
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
- async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
+ pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
test_support!(self, {
let query = "
SELECT *
@@ -1511,7 +1384,7 @@ impl Db for RealDb {
}
#[cfg(any(test, feature = "seed-support"))]
- async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
+ pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
test_support!(self, {
let query = "
INSERT INTO orgs (name, slug)
@@ -1528,7 +1401,12 @@ impl Db for RealDb {
}
#[cfg(any(test, feature = "seed-support"))]
- async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
+ pub async fn add_org_member(
+ &self,
+ org_id: OrgId,
+ user_id: UserId,
+ is_admin: bool,
+ ) -> Result<()> {
test_support!(self, {
let query = "
INSERT INTO org_memberships (org_id, user_id, admin)
@@ -1548,7 +1426,7 @@ impl Db for RealDb {
// channels
#[cfg(any(test, feature = "seed-support"))]
- async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
+ pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
test_support!(self, {
let query = "
INSERT INTO channels (owner_id, owner_is_user, name)
@@ -1566,7 +1444,7 @@ impl Db for RealDb {
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
- async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
+ pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
test_support!(self, {
let query = "
SELECT *
@@ -1582,7 +1460,7 @@ impl Db for RealDb {
})
}
- async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
+ pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
test_support!(self, {
let query = "
SELECT
@@ -1600,7 +1478,7 @@ impl Db for RealDb {
})
}
- async fn can_user_access_channel(
+ pub async fn can_user_access_channel(
&self,
user_id: UserId,
channel_id: ChannelId,
@@ -1622,7 +1500,7 @@ impl Db for RealDb {
}
#[cfg(any(test, feature = "seed-support"))]
- async fn add_channel_member(
+ pub async fn add_channel_member(
&self,
channel_id: ChannelId,
user_id: UserId,
@@ -1646,7 +1524,7 @@ impl Db for RealDb {
// messages
- async fn create_channel_message(
+ pub async fn create_channel_message(
&self,
channel_id: ChannelId,
sender_id: UserId,
@@ -1673,7 +1551,7 @@ impl Db for RealDb {
})
}
- async fn get_channel_messages(
+ pub async fn get_channel_messages(
&self,
channel_id: ChannelId,
count: usize,
@@ -1704,9 +1582,7 @@ impl Db for RealDb {
}
#[cfg(test)]
- async fn teardown(&self, url: &str) {
- let start = std::time::Instant::now();
- eprintln!("tearing down database...");
+ pub async fn teardown(&self, url: &str) {
test_support!(self, {
use util::ResultExt;
@@ -1720,14 +1596,8 @@ impl Db for RealDb {
<sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
.await
.log_err();
- eprintln!("tore down database: {:?}", start.elapsed());
})
}
-
- #[cfg(test)]
- fn as_fake(&self) -> Option<&FakeDb> {
- None
- }
}
macro_rules! id_type {
@@ -1937,661 +1807,13 @@ pub use test::*;
#[cfg(test)]
mod test {
use super::*;
- use anyhow::anyhow;
- use collections::BTreeMap;
use gpui::executor::Background;
- use parking_lot::Mutex;
use rand::prelude::*;
use sqlx::{migrate::MigrateDatabase, Sqlite};
use std::sync::Arc;
- use util::post_inc;
-
- pub struct FakeDb {
- background: Arc<Background>,
- pub users: Mutex<BTreeMap<UserId, User>>,
- pub projects: Mutex<BTreeMap<ProjectId, Project>>,
- pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
- pub orgs: Mutex<BTreeMap<OrgId, Org>>,
- pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
- pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
- pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
- pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
- pub contacts: Mutex<Vec<FakeContact>>,
- next_channel_message_id: Mutex<i32>,
- next_user_id: Mutex<i32>,
- next_org_id: Mutex<i32>,
- next_channel_id: Mutex<i32>,
- next_project_id: Mutex<i32>,
- }
-
- #[derive(Debug)]
- pub struct FakeContact {
- pub requester_id: UserId,
- pub responder_id: UserId,
- pub accepted: bool,
- pub should_notify: bool,
- }
-
- impl FakeDb {
- pub fn new(background: Arc<Background>) -> Self {
- Self {
- background,
- users: Default::default(),
- next_user_id: Mutex::new(0),
- projects: Default::default(),
- worktree_extensions: Default::default(),
- next_project_id: Mutex::new(1),
- orgs: Default::default(),
- next_org_id: Mutex::new(1),
- org_memberships: Default::default(),
- channels: Default::default(),
- next_channel_id: Mutex::new(1),
- channel_memberships: Default::default(),
- channel_messages: Default::default(),
- next_channel_message_id: Mutex::new(1),
- contacts: Default::default(),
- }
- }
- }
-
- #[async_trait]
- impl Db for FakeDb {
- async fn create_user(
- &self,
- email_address: &str,
- admin: bool,
- params: NewUserParams,
- ) -> Result<NewUserResult> {
- self.background.simulate_random_delay().await;
-
- let mut users = self.users.lock();
- let user_id = if let Some(user) = users
- .values()
- .find(|user| user.github_login == params.github_login)
- {
- user.id
- } else {
- let id = post_inc(&mut *self.next_user_id.lock());
- let user_id = UserId(id);
- users.insert(
- user_id,
- User {
- id: user_id,
- github_login: params.github_login,
- github_user_id: Some(params.github_user_id),
- email_address: Some(email_address.to_string()),
- admin,
- invite_code: None,
- invite_count: 0,
- connected_once: false,
- },
- );
- user_id
- };
- Ok(NewUserResult {
- user_id,
- metrics_id: "the-metrics-id".to_string(),
- inviting_user_id: None,
- signup_device_id: None,
- })
- }
-
- async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
- unimplemented!()
- }
-
- async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
- unimplemented!()
- }
-
- async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
- self.background.simulate_random_delay().await;
- Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
- }
-
- async fn get_user_metrics_id(&self, _id: UserId) -> Result<String> {
- Ok("the-metrics-id".to_string())
- }
-
- async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
- self.background.simulate_random_delay().await;
- let users = self.users.lock();
- Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
- }
-
- async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
- unimplemented!()
- }
-
- async fn get_user_by_github_account(
- &self,
- github_login: &str,
- github_user_id: Option<i32>,
- ) -> Result<Option<User>> {
- self.background.simulate_random_delay().await;
- if let Some(github_user_id) = github_user_id {
- for user in self.users.lock().values_mut() {
- if user.github_user_id == Some(github_user_id) {
- user.github_login = github_login.into();
- return Ok(Some(user.clone()));
- }
- if user.github_login == github_login {
- user.github_user_id = Some(github_user_id);
- return Ok(Some(user.clone()));
- }
- }
- Ok(None)
- } else {
- Ok(self
- .users
- .lock()
- .values()
- .find(|user| user.github_login == github_login)
- .cloned())
- }
- }
-
- async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
- unimplemented!()
- }
-
- async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
- self.background.simulate_random_delay().await;
- let mut users = self.users.lock();
- let mut user = users
- .get_mut(&id)
- .ok_or_else(|| anyhow!("user not found"))?;
- user.connected_once = connected_once;
- Ok(())
- }
-
- async fn destroy_user(&self, _id: UserId) -> Result<()> {
- unimplemented!()
- }
-
- // signups
-
- async fn create_signup(&self, _signup: Signup) -> Result<()> {
- unimplemented!()
- }
-
- async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
- unimplemented!()
- }
-
- async fn get_unsent_invites(&self, _count: usize) -> Result<Vec<Invite>> {
- unimplemented!()
- }
-
- async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
- unimplemented!()
- }
-
- async fn create_user_from_invite(
- &self,
- _invite: &Invite,
- _user: NewUserParams,
- ) -> Result<Option<NewUserResult>> {
- unimplemented!()
- }
-
- // invite codes
-
- async fn set_invite_count_for_user(&self, _id: UserId, _count: u32) -> Result<()> {
- unimplemented!()
- }
-
- async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
- self.background.simulate_random_delay().await;
- Ok(None)
- }
-
- async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
- unimplemented!()
- }
-
- async fn create_invite_from_code(
- &self,
- _code: &str,
- _email_address: &str,
- _device_id: Option<&str>,
- ) -> Result<Invite> {
- unimplemented!()
- }
-
- // projects
-
- async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
- self.background.simulate_random_delay().await;
- if !self.users.lock().contains_key(&host_user_id) {
- Err(anyhow!("no such user"))?;
- }
-
- let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
- self.projects.lock().insert(
- project_id,
- Project {
- id: project_id,
- host_user_id,
- unregistered: false,
- },
- );
- Ok(project_id)
- }
-
- async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
- self.background.simulate_random_delay().await;
- self.projects
- .lock()
- .get_mut(&project_id)
- .ok_or_else(|| anyhow!("no such project"))?
- .unregistered = true;
- Ok(())
- }
-
- async fn update_worktree_extensions(
- &self,
- project_id: ProjectId,
- worktree_id: u64,
- extensions: HashMap<String, u32>,
- ) -> Result<()> {
- self.background.simulate_random_delay().await;
- if !self.projects.lock().contains_key(&project_id) {
- Err(anyhow!("no such project"))?;
- }
-
- for (extension, count) in extensions {
- self.worktree_extensions
- .lock()
- .insert((project_id, worktree_id, extension), count);
- }
-
- Ok(())
- }
-
- async fn get_project_extensions(
- &self,
- _project_id: ProjectId,
- ) -> Result<HashMap<u64, HashMap<String, usize>>> {
- unimplemented!()
- }
-
- async fn record_user_activity(
- &self,
- _time_period: Range<OffsetDateTime>,
- _active_projects: &[(UserId, ProjectId)],
- ) -> Result<()> {
- unimplemented!()
- }
-
- async fn get_active_user_count(
- &self,
- _time_period: Range<OffsetDateTime>,
- _min_duration: Duration,
- _only_collaborative: bool,
- ) -> Result<usize> {
- unimplemented!()
- }
-
- async fn get_top_users_activity_summary(
- &self,
- _time_period: Range<OffsetDateTime>,
- _limit: usize,
- ) -> Result<Vec<UserActivitySummary>> {
- unimplemented!()
- }
-
- async fn get_user_activity_timeline(
- &self,
- _time_period: Range<OffsetDateTime>,
- _user_id: UserId,
- ) -> Result<Vec<UserActivityPeriod>> {
- unimplemented!()
- }
-
- // contacts
-
- async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
- self.background.simulate_random_delay().await;
- let mut contacts = Vec::new();
-
- for contact in self.contacts.lock().iter() {
- if contact.requester_id == id {
- if contact.accepted {
- contacts.push(Contact::Accepted {
- user_id: contact.responder_id,
- should_notify: contact.should_notify,
- });
- } else {
- contacts.push(Contact::Outgoing {
- user_id: contact.responder_id,
- });
- }
- } else if contact.responder_id == id {
- if contact.accepted {
- contacts.push(Contact::Accepted {
- user_id: contact.requester_id,
- should_notify: false,
- });
- } else {
- contacts.push(Contact::Incoming {
- user_id: contact.requester_id,
- should_notify: contact.should_notify,
- });
- }
- }
- }
-
- contacts.sort_unstable_by_key(|contact| contact.user_id());
- Ok(contacts)
- }
-
- async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
- self.background.simulate_random_delay().await;
- Ok(self.contacts.lock().iter().any(|contact| {
- contact.accepted
- && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
- || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
- }))
- }
-
- async fn send_contact_request(
- &self,
- requester_id: UserId,
- responder_id: UserId,
- ) -> Result<()> {
- self.background.simulate_random_delay().await;
- let mut contacts = self.contacts.lock();
- for contact in contacts.iter_mut() {
- if contact.requester_id == requester_id && contact.responder_id == responder_id {
- if contact.accepted {
- Err(anyhow!("contact already exists"))?;
- } else {
- Err(anyhow!("contact already requested"))?;
- }
- }
- if contact.responder_id == requester_id && contact.requester_id == responder_id {
- if contact.accepted {
- Err(anyhow!("contact already exists"))?;
- } else {
- contact.accepted = true;
- contact.should_notify = false;
- return Ok(());
- }
- }
- }
- contacts.push(FakeContact {
- requester_id,
- responder_id,
- accepted: false,
- should_notify: true,
- });
- Ok(())
- }
-
- async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
- self.background.simulate_random_delay().await;
- self.contacts.lock().retain(|contact| {
- !(contact.requester_id == requester_id && contact.responder_id == responder_id)
- });
- Ok(())
- }
-
- async fn dismiss_contact_notification(
- &self,
- user_id: UserId,
- contact_user_id: UserId,
- ) -> Result<()> {
- self.background.simulate_random_delay().await;
- let mut contacts = self.contacts.lock();
- for contact in contacts.iter_mut() {
- if contact.requester_id == contact_user_id
- && contact.responder_id == user_id
- && !contact.accepted
- {
- contact.should_notify = false;
- return Ok(());
- }
- if contact.requester_id == user_id
- && contact.responder_id == contact_user_id
- && contact.accepted
- {
- contact.should_notify = false;
- return Ok(());
- }
- }
- Err(anyhow!("no such notification"))?
- }
-
- async fn respond_to_contact_request(
- &self,
- responder_id: UserId,
- requester_id: UserId,
- accept: bool,
- ) -> Result<()> {
- self.background.simulate_random_delay().await;
- let mut contacts = self.contacts.lock();
- for (ix, contact) in contacts.iter_mut().enumerate() {
- if contact.requester_id == requester_id && contact.responder_id == responder_id {
- if contact.accepted {
- Err(anyhow!("contact already confirmed"))?;
- }
- if accept {
- contact.accepted = true;
- contact.should_notify = true;
- } else {
- contacts.remove(ix);
- }
- return Ok(());
- }
- }
- Err(anyhow!("no such contact request"))?
- }
-
- async fn create_access_token_hash(
- &self,
- _user_id: UserId,
- _access_token_hash: &str,
- _max_access_token_count: usize,
- ) -> Result<()> {
- unimplemented!()
- }
-
- async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
- unimplemented!()
- }
-
- async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
- unimplemented!()
- }
-
- async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
- self.background.simulate_random_delay().await;
- let mut orgs = self.orgs.lock();
- if orgs.values().any(|org| org.slug == slug) {
- Err(anyhow!("org already exists"))?
- } else {
- let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
- orgs.insert(
- org_id,
- Org {
- id: org_id,
- name: name.to_string(),
- slug: slug.to_string(),
- },
- );
- Ok(org_id)
- }
- }
-
- async fn add_org_member(
- &self,
- org_id: OrgId,
- user_id: UserId,
- is_admin: bool,
- ) -> Result<()> {
- self.background.simulate_random_delay().await;
- if !self.orgs.lock().contains_key(&org_id) {
- Err(anyhow!("org does not exist"))?;
- }
- if !self.users.lock().contains_key(&user_id) {
- Err(anyhow!("user does not exist"))?;
- }
-
- self.org_memberships
- .lock()
- .entry((org_id, user_id))
- .or_insert(is_admin);
- Ok(())
- }
-
- async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
- self.background.simulate_random_delay().await;
- if !self.orgs.lock().contains_key(&org_id) {
- Err(anyhow!("org does not exist"))?;
- }
-
- let mut channels = self.channels.lock();
- let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
- channels.insert(
- channel_id,
- Channel {
- id: channel_id,
- name: name.to_string(),
- owner_id: org_id.0,
- owner_is_user: false,
- },
- );
- Ok(channel_id)
- }
-
- async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
- self.background.simulate_random_delay().await;
- Ok(self
- .channels
- .lock()
- .values()
- .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
- .cloned()
- .collect())
- }
-
- async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
- self.background.simulate_random_delay().await;
- let channels = self.channels.lock();
- let memberships = self.channel_memberships.lock();
- Ok(channels
- .values()
- .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
- .cloned()
- .collect())
- }
-
- async fn can_user_access_channel(
- &self,
- user_id: UserId,
- channel_id: ChannelId,
- ) -> Result<bool> {
- self.background.simulate_random_delay().await;
- Ok(self
- .channel_memberships
- .lock()
- .contains_key(&(channel_id, user_id)))
- }
-
- async fn add_channel_member(
- &self,
- channel_id: ChannelId,
- user_id: UserId,
- is_admin: bool,
- ) -> Result<()> {
- self.background.simulate_random_delay().await;
- if !self.channels.lock().contains_key(&channel_id) {
- Err(anyhow!("channel does not exist"))?;
- }
- if !self.users.lock().contains_key(&user_id) {
- Err(anyhow!("user does not exist"))?;
- }
-
- self.channel_memberships
- .lock()
- .entry((channel_id, user_id))
- .or_insert(is_admin);
- Ok(())
- }
-
- async fn create_channel_message(
- &self,
- channel_id: ChannelId,
- sender_id: UserId,
- body: &str,
- timestamp: OffsetDateTime,
- nonce: u128,
- ) -> Result<MessageId> {
- self.background.simulate_random_delay().await;
- if !self.channels.lock().contains_key(&channel_id) {
- Err(anyhow!("channel does not exist"))?;
- }
- if !self.users.lock().contains_key(&sender_id) {
- Err(anyhow!("user does not exist"))?;
- }
-
- let mut messages = self.channel_messages.lock();
- if let Some(message) = messages
- .values()
- .find(|message| message.nonce.as_u128() == nonce)
- {
- Ok(message.id)
- } else {
- let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
- messages.insert(
- message_id,
- ChannelMessage {
- id: message_id,
- channel_id,
- sender_id,
- body: body.to_string(),
- sent_at: timestamp,
- nonce: Uuid::from_u128(nonce),
- },
- );
- Ok(message_id)
- }
- }
-
- async fn get_channel_messages(
- &self,
- channel_id: ChannelId,
- count: usize,
- before_id: Option<MessageId>,
- ) -> Result<Vec<ChannelMessage>> {
- self.background.simulate_random_delay().await;
- let mut messages = self
- .channel_messages
- .lock()
- .values()
- .rev()
- .filter(|message| {
- message.channel_id == channel_id
- && message.id < before_id.unwrap_or(MessageId::MAX)
- })
- .take(count)
- .cloned()
- .collect::<Vec<_>>();
- messages.sort_unstable_by_key(|message| message.id);
- Ok(messages)
- }
-
- async fn teardown(&self, _: &str) {}
-
- #[cfg(test)]
- fn as_fake(&self) -> Option<&FakeDb> {
- Some(self)
- }
- }
pub struct TestDb {
- pub db: Option<Arc<dyn Db>>,
+ pub db: Option<Arc<Db>>,
pub url: String,
}