diff --git a/server/src/db.rs b/server/src/db.rs index 2ae8fc8f1d310c9d61c30f87d394ab9d8c5a495d..7c063e140f3f7b5d5553a73249639562ebaab209 100644 --- a/server/src/db.rs +++ b/server/src/db.rs @@ -68,21 +68,6 @@ impl Db { }) } - #[cfg(test)] - pub fn test(url: &str, max_connections: u32) -> Self { - let mut db = block_on(Self::new(url, max_connections)).unwrap(); - db.test_mode = true; - db - } - - #[cfg(test)] - pub fn migrate(&self, path: &std::path::Path) { - block_on(async { - let migrator = sqlx::migrate::Migrator::new(path).await.unwrap(); - migrator.run(&self.db).await.unwrap(); - }); - } - // signups pub async fn create_signup( @@ -457,3 +442,54 @@ id_type!(OrgId); id_type!(ChannelId); id_type!(SignupId); id_type!(MessageId); + +#[cfg(test)] +pub mod tests { + use super::*; + use rand::prelude::*; + use sqlx::{ + migrate::{MigrateDatabase, Migrator}, + Postgres, + }; + use std::path::Path; + + pub struct TestDb { + pub name: String, + pub url: String, + } + + impl TestDb { + pub fn new() -> (Self, Db) { + // Enable tests to run in parallel by serializing the creation of each test database. + lazy_static::lazy_static! { + static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(()); + } + + let mut rng = StdRng::from_entropy(); + let name = format!("zed-test-{}", rng.gen::()); + let url = format!("postgres://postgres@localhost/{}", name); + let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")); + let db = block_on(async { + { + let _lock = DB_CREATION.lock(); + Postgres::create_database(&url) + .await + .expect("failed to create test db"); + } + let mut db = Db::new(&url, 5).await.unwrap(); + db.test_mode = true; + let migrator = Migrator::new(migrations_path).await.unwrap(); + migrator.run(&db.db).await.unwrap(); + db + }); + + (Self { name, url }, db) + } + } + + impl Drop for TestDb { + fn drop(&mut self) { + block_on(Postgres::drop_database(&self.url)).unwrap(); + } + } +} diff --git a/server/src/rpc.rs b/server/src/rpc.rs index c869dd1aea385ccbcf608291a4534bf0dbd13c2c..0822ddbef1c42430b7d8f84937c566512e33d95c 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -919,18 +919,14 @@ mod tests { use super::*; use crate::{ auth, - db::{self, UserId}, + db::{tests::TestDb, Db, UserId}, github, AppState, Config, }; - use async_std::{ - sync::RwLockReadGuard, - task::{self, block_on}, - }; + use async_std::{sync::RwLockReadGuard, task}; use gpui::TestAppContext; use postage::mpsc; - use rand::prelude::*; use serde_json::json; - use sqlx::{migrate::MigrateDatabase, types::time::OffsetDateTime, Postgres}; + use sqlx::types::time::OffsetDateTime; use std::{path::Path, sync::Arc, time::Duration}; use zed::{ channel::{Channel, ChannelDetails, ChannelList}, @@ -1533,15 +1529,14 @@ mod tests { peer: Arc, app_state: Arc, server: Arc, - db_name: String, + test_db: TestDb, notifications: mpsc::Receiver<()>, } impl TestServer { async fn start() -> Self { - let mut rng = StdRng::from_entropy(); - let db_name = format!("zed-test-{}", rng.gen::()); - let app_state = Self::build_app_state(&db_name).await; + let (test_db, db) = TestDb::new(); + let app_state = Self::build_app_state(&test_db, db).await; let peer = Peer::new(); let notifications = mpsc::channel(128); let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0)); @@ -1549,7 +1544,7 @@ mod tests { peer, app_state, server, - db_name, + test_db, notifications: notifications.1, } } @@ -1575,18 +1570,10 @@ mod tests { (user_id, client) } - async fn build_app_state(db_name: &str) -> Arc { + async fn build_app_state(test_db: &TestDb, db: Db) -> Arc { let mut config = Config::default(); config.session_secret = "a".repeat(32); - config.database_url = format!("postgres://postgres@localhost/{}", db_name); - - Self::create_db(&config.database_url); - let db = db::Db::test(&config.database_url, 5); - db.migrate(Path::new(concat!( - env!("CARGO_MANIFEST_DIR"), - "/migrations" - ))); - + config.database_url = test_db.url.clone(); let github_client = github::AppClient::test(); Arc::new(AppState { db, @@ -1598,16 +1585,6 @@ mod tests { }) } - fn create_db(url: &str) { - // Enable tests to run in parallel by serializing the creation of each test database. - lazy_static::lazy_static! { - static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(()); - } - - let _lock = DB_CREATION.lock(); - block_on(Postgres::create_database(url)).expect("failed to create test database"); - } - async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> { self.server.state.read().await } @@ -1630,10 +1607,7 @@ mod tests { fn drop(&mut self) { task::block_on(async { self.peer.reset().await; - self.app_state.db.close(&self.db_name).await; - Postgres::drop_database(&self.app_state.config.database_url) - .await - .unwrap(); + self.app_state.db.close(&self.test_db.name).await; }); } }