1mod buffer_tests;
2mod channel_tests;
3mod db_tests;
4mod extension_tests;
5mod migrations;
6
7use std::sync::Arc;
8use std::sync::atomic::{AtomicI32, Ordering::SeqCst};
9use std::time::Duration;
10
11use gpui::BackgroundExecutor;
12use parking_lot::Mutex;
13use rand::prelude::*;
14use sea_orm::ConnectionTrait;
15use sqlx::migrate::MigrateDatabase;
16
17use self::migrations::run_database_migrations;
18
19use super::*;
20
21pub struct TestDb {
22 pub db: Option<Arc<Database>>,
23 pub connection: Option<sqlx::AnyConnection>,
24}
25
26impl TestDb {
27 pub fn sqlite(executor: BackgroundExecutor) -> Self {
28 let url = "sqlite::memory:";
29 let runtime = tokio::runtime::Builder::new_current_thread()
30 .enable_io()
31 .enable_time()
32 .build()
33 .unwrap();
34
35 let mut db = runtime.block_on(async {
36 let mut options = ConnectOptions::new(url);
37 options.max_connections(5);
38 let mut db = Database::new(options).await.unwrap();
39 let sql = include_str!(concat!(
40 env!("CARGO_MANIFEST_DIR"),
41 "/migrations.sqlite/20221109000000_test_schema.sql"
42 ));
43 db.pool
44 .execute(sea_orm::Statement::from_string(
45 db.pool.get_database_backend(),
46 sql,
47 ))
48 .await
49 .unwrap();
50 db.initialize_notification_kinds().await.unwrap();
51 db
52 });
53
54 db.test_options = Some(DatabaseTestOptions {
55 executor,
56 runtime,
57 query_failure_probability: parking_lot::Mutex::new(0.0),
58 });
59
60 Self {
61 db: Some(Arc::new(db)),
62 connection: None,
63 }
64 }
65
66 pub fn postgres(executor: BackgroundExecutor) -> Self {
67 static LOCK: Mutex<()> = Mutex::new(());
68
69 let _guard = LOCK.lock();
70 let mut rng = StdRng::from_os_rng();
71 let url = format!(
72 "postgres://postgres@localhost/zed-test-{}",
73 rng.random::<u128>()
74 );
75 let runtime = tokio::runtime::Builder::new_current_thread()
76 .enable_io()
77 .enable_time()
78 .build()
79 .unwrap();
80
81 let mut db = runtime.block_on(async {
82 sqlx::Postgres::create_database(&url)
83 .await
84 .expect("failed to create test db");
85 let mut options = ConnectOptions::new(url);
86 options
87 .max_connections(5)
88 .idle_timeout(Duration::from_secs(0));
89 let mut db = Database::new(options).await.unwrap();
90 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
91 run_database_migrations(db.options(), migrations_path)
92 .await
93 .unwrap();
94 db.initialize_notification_kinds().await.unwrap();
95 db
96 });
97
98 db.test_options = Some(DatabaseTestOptions {
99 executor,
100 runtime,
101 query_failure_probability: parking_lot::Mutex::new(0.0),
102 });
103
104 Self {
105 db: Some(Arc::new(db)),
106 connection: None,
107 }
108 }
109
110 pub fn db(&self) -> &Arc<Database> {
111 self.db.as_ref().unwrap()
112 }
113
114 pub fn set_query_failure_probability(&self, probability: f64) {
115 let database = self.db.as_ref().unwrap();
116 let test_options = database.test_options.as_ref().unwrap();
117 *test_options.query_failure_probability.lock() = probability;
118 }
119}
120
/// Generates two `#[gpui::test]` wrappers around an async test function that
/// accepts `&Arc<Database>`.
///
/// - The first wrapper runs the test against a Postgres `TestDb` and is only
///   compiled on macOS.
/// - The second wrapper runs it against a SQLite `TestDb`.
///
/// Usage: `test_both_dbs!(my_test, test_my_test_postgres, test_my_test_sqlite);`
#[macro_export]
macro_rules! test_both_dbs {
    ($body:ident, $pg_name:ident, $sqlite_name:ident) => {
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $pg_name(cx: &mut gpui::TestAppContext) {
            let harness = $crate::db::TestDb::postgres(cx.executor().clone());
            $body(harness.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_name(cx: &mut gpui::TestAppContext) {
            let harness = $crate::db::TestDb::sqlite(cx.executor().clone());
            $body(harness.db()).await;
        }
    };
}
138
139impl Drop for TestDb {
140 fn drop(&mut self) {
141 let db = self.db.take().unwrap();
142 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
143 db.test_options.as_ref().unwrap().runtime.block_on(async {
144 use util::ResultExt;
145 let query = "
146 SELECT pg_terminate_backend(pg_stat_activity.pid)
147 FROM pg_stat_activity
148 WHERE
149 pg_stat_activity.datname = current_database() AND
150 pid <> pg_backend_pid();
151 ";
152 db.pool
153 .execute(sea_orm::Statement::from_string(
154 db.pool.get_database_backend(),
155 query,
156 ))
157 .await
158 .log_err();
159 sqlx::Postgres::drop_database(db.options.get_url())
160 .await
161 .log_err();
162 })
163 }
164 }
165}
166
167#[track_caller]
168fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
169 let expected_channels = expected.into_iter().collect::<HashSet<_>>();
170 let actual_channels = actual.into_iter().collect::<HashSet<_>>();
171 pretty_assertions::assert_eq!(expected_channels, actual_channels);
172}
173
174fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
175 use std::collections::HashMap;
176
177 let mut result = Vec::new();
178 let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();
179
180 for (id, parent_path, name) in channels {
181 let parent_key = parent_path.to_vec();
182 let order = if parent_key.is_empty() {
183 1
184 } else {
185 *order_by_parent
186 .entry(parent_key.clone())
187 .and_modify(|e| *e += 1)
188 .or_insert(1)
189 };
190
191 result.push(Channel {
192 id: *id,
193 name: (*name).to_owned(),
194 visibility: ChannelVisibility::Members,
195 parent_path: parent_key,
196 channel_order: order,
197 });
198 }
199
200 result
201}
202
203static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
204
205async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
206 db.create_user(
207 email,
208 None,
209 false,
210 NewUserParams {
211 github_login: email[0..email.find('@').unwrap()].to_string(),
212 github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
213 },
214 )
215 .await
216 .unwrap()
217 .user_id
218}