Run tests with an in-memory sqlite database

Max Brunsfeld created

Change summary

crates/collab/src/db.rs | 73 ++++++++++++++++++++++++++++++++++--------
1 file changed, 58 insertions(+), 15 deletions(-)

Detailed changes

crates/collab/src/db.rs 🔗

@@ -22,6 +22,8 @@ pub struct Db<D: sqlx::Database> {
     pool: sqlx::Pool<D>,
     #[cfg(test)]
     background: Option<std::sync::Arc<gpui::executor::Background>>,
+    #[cfg(test)]
+    runtime: Option<tokio::runtime::Runtime>,
 }
 
 macro_rules! test_support {
@@ -35,7 +37,8 @@ macro_rules! test_support {
             if let Some(background) = $self.background.as_ref() {
                 background.simulate_random_delay().await;
             }
-            tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build().unwrap().block_on(body)
+            #[cfg(test)]
+            $self.runtime.as_ref().unwrap().block_on(body)
         } else {
             body.await
         }
@@ -60,17 +63,29 @@ impl RowsAffected for sqlx::postgres::PgQueryResult {
 
 #[cfg(test)]
 impl Db<sqlx::Sqlite> {
+    const MIGRATIONS_PATH: &'static str = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
+
     pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
+        use std::str::FromStr as _;
+        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)
+            .unwrap()
+            .create_if_missing(true)
+            .shared_cache(true);
         let pool = sqlx::sqlite::SqlitePoolOptions::new()
+            .min_connections(2)
             .max_connections(max_connections)
-            .connect(url)
+            .connect_with(options)
             .await?;
         Ok(Self {
             pool,
             background: None,
+            runtime: None,
         })
     }
 
+    #[cfg(test)]
+    pub fn teardown(&self, _url: &str) {}
+
     pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
         test_support!(self, {
             let query = "
@@ -143,6 +158,8 @@ impl Db<sqlx::Sqlite> {
 }
 
 impl Db<sqlx::Postgres> {
+    const MIGRATIONS_PATH: &'static str = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
+
     pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
         let pool = sqlx::postgres::PgPoolOptions::new()
             .max_connections(max_connections)
@@ -152,6 +169,25 @@ impl Db<sqlx::Postgres> {
             pool,
             #[cfg(test)]
             background: None,
+            #[cfg(test)]
+            runtime: None,
+        })
+    }
+
+    #[cfg(test)]
+    pub fn teardown(&self, url: &str) {
+        self.runtime.as_ref().unwrap().block_on(async {
+            use util::ResultExt;
+            let query = "
+                SELECT pg_terminate_backend(pg_stat_activity.pid)
+                FROM pg_stat_activity
+                WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
+            ";
+            sqlx::query(query).execute(&self.pool).await.log_err();
+            self.pool.close().await;
+            <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
+                .await
+                .log_err();
         })
     }
 
@@ -1295,33 +1331,39 @@ mod test {
     use super::*;
     use gpui::executor::Background;
     use rand::prelude::*;
-    use sqlx::migrate::MigrateDatabase;
     use std::sync::Arc;
 
     pub struct TestDb {
         pub db: Option<Arc<DefaultDb>>,
+        pub conn: sqlx::sqlite::SqliteConnection,
         pub url: String,
     }
 
     impl TestDb {
         pub fn new(background: Arc<Background>) -> Self {
             let mut rng = StdRng::from_entropy();
-            let url = format!("/tmp/zed-test-{}", rng.gen::<u128>());
-            let db = tokio::runtime::Builder::new_current_thread()
+            let url = format!("file:zed-test-{}?mode=memory", rng.gen::<u128>());
+            let runtime = tokio::runtime::Builder::new_current_thread()
                 .enable_io()
                 .enable_time()
                 .build()
-                .unwrap()
-                .block_on(async {
-                    sqlx::Sqlite::create_database(&url).await.unwrap();
-                    let mut db = DefaultDb::new(&url, 5).await.unwrap();
-                    db.background = Some(background);
-                    let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
-                    db.migrate(Path::new(migrations_path), false).await.unwrap();
-                    db
-                });
+                .unwrap();
+
+            let (mut db, conn) = runtime.block_on(async {
+                let db = DefaultDb::new(&url, 5).await.unwrap();
+                db.migrate(Path::new(DefaultDb::MIGRATIONS_PATH), false)
+                    .await
+                    .unwrap();
+                let conn = db.pool.acquire().await.unwrap().detach();
+                (db, conn)
+            });
+
+            db.background = Some(background);
+            db.runtime = Some(runtime);
+
             Self {
                 db: Some(Arc::new(db)),
+                conn,
                 url,
             }
         }
@@ -1333,7 +1375,8 @@ mod test {
 
     impl Drop for TestDb {
         fn drop(&mut self) {
-            std::fs::remove_file(&self.url).ok();
+            let db = self.db.take().unwrap();
+            db.teardown(&self.url);
         }
     }
 }