Extract a `TestDb` to setup and tear down a database instance

Antonio Scandurra created

Change summary

server/src/db.rs  | 66 +++++++++++++++++++++++++++++++++++++-----------
server/src/rpc.rs | 46 +++++++--------------------------
2 files changed, 61 insertions(+), 51 deletions(-)

Detailed changes

server/src/db.rs 🔗

@@ -68,21 +68,6 @@ impl Db {
         })
     }
 
-    #[cfg(test)]
-    pub fn test(url: &str, max_connections: u32) -> Self {
-        let mut db = block_on(Self::new(url, max_connections)).unwrap();
-        db.test_mode = true;
-        db
-    }
-
-    #[cfg(test)]
-    pub fn migrate(&self, path: &std::path::Path) {
-        block_on(async {
-            let migrator = sqlx::migrate::Migrator::new(path).await.unwrap();
-            migrator.run(&self.db).await.unwrap();
-        });
-    }
-
     // signups
 
     pub async fn create_signup(
@@ -457,3 +442,54 @@ id_type!(OrgId);
 id_type!(ChannelId);
 id_type!(SignupId);
 id_type!(MessageId);
+
+#[cfg(test)]
+pub mod tests {
+    use super::*;
+    use rand::prelude::*;
+    use sqlx::{
+        migrate::{MigrateDatabase, Migrator},
+        Postgres,
+    };
+    use std::path::Path;
+
+    pub struct TestDb {
+        pub name: String,
+        pub url: String,
+    }
+
+    impl TestDb {
+        pub fn new() -> (Self, Db) {
+            // Enable tests to run in parallel by serializing the creation of each test database.
+            lazy_static::lazy_static! {
+                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
+            }
+
+            let mut rng = StdRng::from_entropy();
+            let name = format!("zed-test-{}", rng.gen::<u128>());
+            let url = format!("postgres://postgres@localhost/{}", name);
+            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
+            let db = block_on(async {
+                {
+                    let _lock = DB_CREATION.lock();
+                    Postgres::create_database(&url)
+                        .await
+                        .expect("failed to create test db");
+                }
+                let mut db = Db::new(&url, 5).await.unwrap();
+                db.test_mode = true;
+                let migrator = Migrator::new(migrations_path).await.unwrap();
+                migrator.run(&db.db).await.unwrap();
+                db
+            });
+
+            (Self { name, url }, db)
+        }
+    }
+
+    impl Drop for TestDb {
+        fn drop(&mut self) {
+            block_on(Postgres::drop_database(&self.url)).unwrap();
+        }
+    }
+}

server/src/rpc.rs 🔗

@@ -919,18 +919,14 @@ mod tests {
     use super::*;
     use crate::{
         auth,
-        db::{self, UserId},
+        db::{tests::TestDb, Db, UserId},
         github, AppState, Config,
     };
-    use async_std::{
-        sync::RwLockReadGuard,
-        task::{self, block_on},
-    };
+    use async_std::{sync::RwLockReadGuard, task};
     use gpui::TestAppContext;
     use postage::mpsc;
-    use rand::prelude::*;
     use serde_json::json;
-    use sqlx::{migrate::MigrateDatabase, types::time::OffsetDateTime, Postgres};
+    use sqlx::types::time::OffsetDateTime;
     use std::{path::Path, sync::Arc, time::Duration};
     use zed::{
         channel::{Channel, ChannelDetails, ChannelList},
@@ -1533,15 +1529,14 @@ mod tests {
         peer: Arc<Peer>,
         app_state: Arc<AppState>,
         server: Arc<Server>,
-        db_name: String,
+        test_db: TestDb,
         notifications: mpsc::Receiver<()>,
     }
 
     impl TestServer {
         async fn start() -> Self {
-            let mut rng = StdRng::from_entropy();
-            let db_name = format!("zed-test-{}", rng.gen::<u128>());
-            let app_state = Self::build_app_state(&db_name).await;
+            let (test_db, db) = TestDb::new();
+            let app_state = Self::build_app_state(&test_db, db).await;
             let peer = Peer::new();
             let notifications = mpsc::channel(128);
             let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
@@ -1549,7 +1544,7 @@ mod tests {
                 peer,
                 app_state,
                 server,
-                db_name,
+                test_db,
                 notifications: notifications.1,
             }
         }
@@ -1575,18 +1570,10 @@ mod tests {
             (user_id, client)
         }
 
-        async fn build_app_state(db_name: &str) -> Arc<AppState> {
+        async fn build_app_state(test_db: &TestDb, db: Db) -> Arc<AppState> {
             let mut config = Config::default();
             config.session_secret = "a".repeat(32);
-            config.database_url = format!("postgres://postgres@localhost/{}", db_name);
-
-            Self::create_db(&config.database_url);
-            let db = db::Db::test(&config.database_url, 5);
-            db.migrate(Path::new(concat!(
-                env!("CARGO_MANIFEST_DIR"),
-                "/migrations"
-            )));
-
+            config.database_url = test_db.url.clone();
             let github_client = github::AppClient::test();
             Arc::new(AppState {
                 db,
@@ -1598,16 +1585,6 @@ mod tests {
             })
         }
 
-        fn create_db(url: &str) {
-            // Enable tests to run in parallel by serializing the creation of each test database.
-            lazy_static::lazy_static! {
-                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
-            }
-
-            let _lock = DB_CREATION.lock();
-            block_on(Postgres::create_database(url)).expect("failed to create test database");
-        }
-
         async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
             self.server.state.read().await
         }
@@ -1630,10 +1607,7 @@ mod tests {
         fn drop(&mut self) {
             task::block_on(async {
                 self.peer.reset().await;
-                self.app_state.db.close(&self.db_name).await;
-                Postgres::drop_database(&self.app_state.config.database_url)
-                    .await
-                    .unwrap();
+                self.app_state.db.close(&self.test_db.name).await;
             });
         }
     }