Clear test db pool whenever no dbs are in use

Max Brunsfeld created

Change summary

crates/server/src/db.rs | 55 +++++++++++++++++-------------------------
1 file changed, 22 insertions(+), 33 deletions(-)

Detailed changes

crates/server/src/db.rs 🔗

@@ -533,7 +533,11 @@ pub mod tests {
         migrate::{MigrateDatabase, Migrator},
         Postgres,
     };
-    use std::{mem, path::Path};
+    use std::{
+        mem,
+        path::Path,
+        sync::atomic::{AtomicUsize, Ordering::SeqCst},
+    };
     use util::ResultExt as _;
 
     pub struct TestDb {
@@ -543,37 +547,14 @@ pub mod tests {
     }
 
     lazy_static! {
-        static ref POOL: Mutex<Vec<TestDb>> = Default::default();
-    }
-
-    use std::os::raw::c_int;
-
-    extern "C" {
-        fn atexit(callback: extern "C" fn()) -> c_int;
-    }
-
-    #[ctor::ctor]
-    fn init() {
-        unsafe {
-            atexit(teardown_db_pool);
-        }
-    }
-
-    extern "C" fn teardown_db_pool() {
-        std::thread::spawn(|| {
-            block_on(async move {
-                for db in POOL.lock().drain(..) {
-                    db.teardown().await.log_err();
-                }
-            });
-        })
-        .join()
-        .log_err();
+        static ref DB_POOL: Mutex<Vec<TestDb>> = Default::default();
+        static ref DB_COUNT: AtomicUsize = Default::default();
     }
 
     impl TestDb {
         pub fn new() -> Self {
-            let mut pool = POOL.lock();
+            DB_COUNT.fetch_add(1, SeqCst);
+            let mut pool = DB_POOL.lock();
             if let Some(db) = pool.pop() {
                 db.truncate();
                 db
@@ -628,10 +609,10 @@ pub mod tests {
         async fn teardown(mut self) -> Result<()> {
             let db = self.db.take().unwrap();
             let query = "
-                    SELECT pg_terminate_backend(pg_stat_activity.pid)
-                    FROM pg_stat_activity
-                    WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
-                ";
+                SELECT pg_terminate_backend(pg_stat_activity.pid)
+                FROM pg_stat_activity
+                WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
+            ";
             sqlx::query(query)
                 .bind(&self.name)
                 .execute(&db.pool)
@@ -645,11 +626,19 @@ pub mod tests {
     impl Drop for TestDb {
         fn drop(&mut self) {
             if let Some(db) = self.db.take() {
-                POOL.lock().push(TestDb {
+                DB_POOL.lock().push(TestDb {
                     db: Some(db),
                     name: mem::take(&mut self.name),
                     url: mem::take(&mut self.url),
                 });
+                if DB_COUNT.fetch_sub(1, SeqCst) == 1 {
+                    block_on(async move {
+                        let mut pool = DB_POOL.lock();
+                        for db in pool.drain(..) {
+                            db.teardown().await.log_err();
+                        }
+                    });
+                }
             }
         }
     }