// test_db.rs

use super::*;
use gpui::executor::Background;
use parking_lot::Mutex;
use sea_orm::ConnectionTrait;
use sqlx::migrate::MigrateDatabase;
use std::sync::Arc;
use tokio::runtime::Runtime;

  8pub struct TestDb {
  9    pub db: Option<Arc<Database>>,
 10    pub connection: Option<sqlx::AnyConnection>,
 11}
 12
 13impl TestDb {
 14    pub fn sqlite(background: Arc<Background>) -> Self {
 15        let url = format!("sqlite::memory:");
 16        let runtime = tokio::runtime::Builder::new_current_thread()
 17            .enable_io()
 18            .enable_time()
 19            .build()
 20            .unwrap();
 21
 22        let mut db = runtime.block_on(async {
 23            let mut options = ConnectOptions::new(url);
 24            options.max_connections(5);
 25            let db = Database::new(options, Executor::Deterministic(background))
 26                .await
 27                .unwrap();
 28            let sql = include_str!(concat!(
 29                env!("CARGO_MANIFEST_DIR"),
 30                "/migrations.sqlite/20221109000000_test_schema.sql"
 31            ));
 32            db.pool
 33                .execute(sea_orm::Statement::from_string(
 34                    db.pool.get_database_backend(),
 35                    sql.into(),
 36                ))
 37                .await
 38                .unwrap();
 39            db
 40        });
 41
 42        db.runtime = Some(runtime);
 43
 44        Self {
 45            db: Some(Arc::new(db)),
 46            connection: None,
 47        }
 48    }
 49
 50    pub fn postgres(background: Arc<Background>) -> Self {
 51        static LOCK: Mutex<()> = Mutex::new(());
 52
 53        let _guard = LOCK.lock();
 54        let mut rng = StdRng::from_entropy();
 55        let url = format!(
 56            "postgres://postgres@localhost/zed-test-{}",
 57            rng.gen::<u128>()
 58        );
 59        let runtime = tokio::runtime::Builder::new_current_thread()
 60            .enable_io()
 61            .enable_time()
 62            .build()
 63            .unwrap();
 64
 65        let mut db = runtime.block_on(async {
 66            sqlx::Postgres::create_database(&url)
 67                .await
 68                .expect("failed to create test db");
 69            let mut options = ConnectOptions::new(url);
 70            options
 71                .max_connections(5)
 72                .idle_timeout(Duration::from_secs(0));
 73            let db = Database::new(options, Executor::Deterministic(background))
 74                .await
 75                .unwrap();
 76            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
 77            db.migrate(Path::new(migrations_path), false).await.unwrap();
 78            db
 79        });
 80
 81        db.runtime = Some(runtime);
 82
 83        Self {
 84            db: Some(Arc::new(db)),
 85            connection: None,
 86        }
 87    }
 88
 89    pub fn db(&self) -> &Arc<Database> {
 90        self.db.as_ref().unwrap()
 91    }
 92}
 93
 /// Generates two `#[gpui::test]` functions that run the same async test body against
 /// both database backends.
 ///
 /// `$test_name` must be callable with `&Arc<Database>` and return a future (typically
 /// an `async fn`). `$postgres_test_name` and `$sqlite_test_name` name the generated
 /// Postgres-backed and SQLite-backed tests.
 #[macro_export]
 macro_rules! test_both_dbs {
     ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
         #[gpui::test]
         async fn $postgres_test_name() {
             // `$crate` (not `crate`) always resolves to the crate defining this
             // exported macro, regardless of where it is expanded.
             let test_db = $crate::db::test_db::TestDb::postgres(
                 gpui::executor::Deterministic::new(0).build_background(),
             );
             $test_name(test_db.db()).await;
         }
 
         #[gpui::test]
         async fn $sqlite_test_name() {
             let test_db = $crate::db::test_db::TestDb::sqlite(
                 gpui::executor::Deterministic::new(0).build_background(),
             );
             $test_name(test_db.db()).await;
         }
     };
 }

115impl Drop for TestDb {
116    fn drop(&mut self) {
117        let db = self.db.take().unwrap();
118        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
119            db.runtime.as_ref().unwrap().block_on(async {
120                use util::ResultExt;
121                let query = "
122                        SELECT pg_terminate_backend(pg_stat_activity.pid)
123                        FROM pg_stat_activity
124                        WHERE
125                            pg_stat_activity.datname = current_database() AND
126                            pid <> pg_backend_pid();
127                    ";
128                db.pool
129                    .execute(sea_orm::Statement::from_string(
130                        db.pool.get_database_backend(),
131                        query.into(),
132                    ))
133                    .await
134                    .log_err();
135                sqlx::Postgres::drop_database(db.options.get_url())
136                    .await
137                    .log_err();
138            })
139        }
140    }
141}