1use super::*;
2use gpui::executor::Background;
3use parking_lot::Mutex;
4use sea_orm::ConnectionTrait;
5use sqlx::migrate::MigrateDatabase;
6use std::sync::Arc;
7
8pub struct TestDb {
9 pub db: Option<Arc<Database>>,
10 pub connection: Option<sqlx::AnyConnection>,
11}
12
13impl TestDb {
14 pub fn sqlite(background: Arc<Background>) -> Self {
15 let url = format!("sqlite::memory:");
16 let runtime = tokio::runtime::Builder::new_current_thread()
17 .enable_io()
18 .enable_time()
19 .build()
20 .unwrap();
21
22 let mut db = runtime.block_on(async {
23 let mut options = ConnectOptions::new(url);
24 options.max_connections(5);
25 let db = Database::new(options, Executor::Deterministic(background))
26 .await
27 .unwrap();
28 let sql = include_str!(concat!(
29 env!("CARGO_MANIFEST_DIR"),
30 "/migrations.sqlite/20221109000000_test_schema.sql"
31 ));
32 db.pool
33 .execute(sea_orm::Statement::from_string(
34 db.pool.get_database_backend(),
35 sql.into(),
36 ))
37 .await
38 .unwrap();
39 db
40 });
41
42 db.runtime = Some(runtime);
43
44 Self {
45 db: Some(Arc::new(db)),
46 connection: None,
47 }
48 }
49
50 pub fn postgres(background: Arc<Background>) -> Self {
51 static LOCK: Mutex<()> = Mutex::new(());
52
53 let _guard = LOCK.lock();
54 let mut rng = StdRng::from_entropy();
55 let url = format!(
56 "postgres://postgres@localhost/zed-test-{}",
57 rng.gen::<u128>()
58 );
59 let runtime = tokio::runtime::Builder::new_current_thread()
60 .enable_io()
61 .enable_time()
62 .build()
63 .unwrap();
64
65 let mut db = runtime.block_on(async {
66 sqlx::Postgres::create_database(&url)
67 .await
68 .expect("failed to create test db");
69 let mut options = ConnectOptions::new(url);
70 options
71 .max_connections(5)
72 .idle_timeout(Duration::from_secs(0));
73 let db = Database::new(options, Executor::Deterministic(background))
74 .await
75 .unwrap();
76 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
77 db.migrate(Path::new(migrations_path), false).await.unwrap();
78 db
79 });
80
81 db.runtime = Some(runtime);
82
83 Self {
84 db: Some(Arc::new(db)),
85 connection: None,
86 }
87 }
88
89 pub fn db(&self) -> &Arc<Database> {
90 self.db.as_ref().unwrap()
91 }
92}
93
/// Generates two `#[gpui::test]` async functions that run the same test body
/// against a Postgres database and a SQLite database.
///
/// * `$test_name` - the shared test function; it is called with the value of
///   `TestDb::db()` and its result is awaited.
/// * `$postgres_test_name` - name of the generated Postgres test.
/// * `$sqlite_test_name` - name of the generated SQLite test.
///
/// Each generated test builds its fixture through
/// `crate::db::test_db::TestDb`, passing a background executor obtained from
/// `gpui::executor::Deterministic::new(0).build_background()`.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[gpui::test]
        async fn $postgres_test_name() {
            let fixture = crate::db::test_db::TestDb::postgres(
                gpui::executor::Deterministic::new(0).build_background(),
            );
            $test_name(fixture.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_test_name() {
            let fixture = crate::db::test_db::TestDb::sqlite(
                gpui::executor::Deterministic::new(0).build_background(),
            );
            $test_name(fixture.db()).await;
        }
    };
}
114
115impl Drop for TestDb {
116 fn drop(&mut self) {
117 let db = self.db.take().unwrap();
118 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
119 db.runtime.as_ref().unwrap().block_on(async {
120 use util::ResultExt;
121 let query = "
122 SELECT pg_terminate_backend(pg_stat_activity.pid)
123 FROM pg_stat_activity
124 WHERE
125 pg_stat_activity.datname = current_database() AND
126 pid <> pg_backend_pid();
127 ";
128 db.pool
129 .execute(sea_orm::Statement::from_string(
130 db.pool.get_database_backend(),
131 query.into(),
132 ))
133 .await
134 .log_err();
135 sqlx::Postgres::drop_database(db.options.get_url())
136 .await
137 .log_err();
138 })
139 }
140 }
141}