1mod billing_subscription_tests;
2mod buffer_tests;
3mod channel_tests;
4mod contributor_tests;
5mod db_tests;
// Postgres-backed tests currently only run on macOS.
7#[cfg(target_os = "macos")]
8mod embedding_tests;
9mod extension_tests;
10mod feature_flag_tests;
11mod message_tests;
12mod processed_stripe_event_tests;
13mod user_tests;
14
15use crate::migrations::run_database_migrations;
16
17use super::*;
18use gpui::BackgroundExecutor;
19use parking_lot::Mutex;
20use sea_orm::ConnectionTrait;
21use sqlx::migrate::MigrateDatabase;
22use std::sync::{
23 Arc,
24 atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
25};
26
27pub struct TestDb {
28 pub db: Option<Arc<Database>>,
29 pub connection: Option<sqlx::AnyConnection>,
30}
31
32impl TestDb {
33 pub fn sqlite(executor: BackgroundExecutor) -> Self {
34 let url = "sqlite::memory:";
35 let runtime = tokio::runtime::Builder::new_current_thread()
36 .enable_io()
37 .enable_time()
38 .build()
39 .unwrap();
40
41 let mut db = runtime.block_on(async {
42 let mut options = ConnectOptions::new(url);
43 options.max_connections(5);
44 let mut db = Database::new(options, Executor::Deterministic(executor.clone()))
45 .await
46 .unwrap();
47 let sql = include_str!(concat!(
48 env!("CARGO_MANIFEST_DIR"),
49 "/migrations.sqlite/20221109000000_test_schema.sql"
50 ));
51 db.pool
52 .execute(sea_orm::Statement::from_string(
53 db.pool.get_database_backend(),
54 sql,
55 ))
56 .await
57 .unwrap();
58 db.initialize_notification_kinds().await.unwrap();
59 db
60 });
61
62 db.test_options = Some(DatabaseTestOptions {
63 runtime,
64 query_failure_probability: parking_lot::Mutex::new(0.0),
65 });
66
67 Self {
68 db: Some(Arc::new(db)),
69 connection: None,
70 }
71 }
72
73 pub fn postgres(executor: BackgroundExecutor) -> Self {
74 static LOCK: Mutex<()> = Mutex::new(());
75
76 let _guard = LOCK.lock();
77 let mut rng = StdRng::from_entropy();
78 let url = format!(
79 "postgres://postgres@localhost/zed-test-{}",
80 rng.r#gen::<u128>()
81 );
82 let runtime = tokio::runtime::Builder::new_current_thread()
83 .enable_io()
84 .enable_time()
85 .build()
86 .unwrap();
87
88 let mut db = runtime.block_on(async {
89 sqlx::Postgres::create_database(&url)
90 .await
91 .expect("failed to create test db");
92 let mut options = ConnectOptions::new(url);
93 options
94 .max_connections(5)
95 .idle_timeout(Duration::from_secs(0));
96 let mut db = Database::new(options, Executor::Deterministic(executor.clone()))
97 .await
98 .unwrap();
99 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
100 run_database_migrations(db.options(), migrations_path)
101 .await
102 .unwrap();
103 db.initialize_notification_kinds().await.unwrap();
104 db
105 });
106
107 db.test_options = Some(DatabaseTestOptions {
108 runtime,
109 query_failure_probability: parking_lot::Mutex::new(0.0),
110 });
111
112 Self {
113 db: Some(Arc::new(db)),
114 connection: None,
115 }
116 }
117
118 pub fn db(&self) -> &Arc<Database> {
119 self.db.as_ref().unwrap()
120 }
121
122 pub fn set_query_failure_probability(&self, probability: f64) {
123 let database = self.db.as_ref().unwrap();
124 let test_options = database.test_options.as_ref().unwrap();
125 *test_options.query_failure_probability.lock() = probability;
126 }
127}
128
/// Expands to two `#[gpui::test]` functions that both run the async test
/// function `$test_fn`, passing it `&Arc<Database>`.
///
/// - `$postgres_name` runs it against `TestDb::postgres` and is compiled only on macOS.
/// - `$sqlite_name` runs it against `TestDb::sqlite`.
///
/// In both, the `TestDb` fixture stays alive until `$test_fn` has completed.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_fn:ident, $postgres_name:ident, $sqlite_name:ident) => {
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let fixture = $crate::db::TestDb::postgres(executor);
            let database = fixture.db();
            $test_fn(database).await;
        }

        #[gpui::test]
        async fn $sqlite_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let fixture = $crate::db::TestDb::sqlite(executor);
            let database = fixture.db();
            $test_fn(database).await;
        }
    };
}
146
147impl Drop for TestDb {
148 fn drop(&mut self) {
149 let db = self.db.take().unwrap();
150 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
151 db.test_options.as_ref().unwrap().runtime.block_on(async {
152 use util::ResultExt;
153 let query = "
154 SELECT pg_terminate_backend(pg_stat_activity.pid)
155 FROM pg_stat_activity
156 WHERE
157 pg_stat_activity.datname = current_database() AND
158 pid <> pg_backend_pid();
159 ";
160 db.pool
161 .execute(sea_orm::Statement::from_string(
162 db.pool.get_database_backend(),
163 query,
164 ))
165 .await
166 .log_err();
167 sqlx::Postgres::drop_database(db.options.get_url())
168 .await
169 .log_err();
170 })
171 }
172 }
173}
174
175fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
176 channels
177 .iter()
178 .map(|(id, parent_path, name)| Channel {
179 id: *id,
180 name: name.to_string(),
181 visibility: ChannelVisibility::Members,
182 parent_path: parent_path.to_vec(),
183 })
184 .collect()
185}
186
187static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
188
189async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
190 db.create_user(
191 email,
192 None,
193 false,
194 NewUserParams {
195 github_login: email[0..email.find('@').unwrap()].to_string(),
196 github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
197 },
198 )
199 .await
200 .unwrap()
201 .user_id
202}
203
204static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
205fn new_test_connection(server: ServerId) -> ConnectionId {
206 ConnectionId {
207 id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
208 owner_id: server.0 as u32,
209 }
210}