1mod buffer_tests;
2mod channel_tests;
3mod contributor_tests;
4mod db_tests;
5mod extension_tests;
6mod migrations;
7
8use std::sync::Arc;
9use std::sync::atomic::{AtomicI32, Ordering::SeqCst};
10use std::time::Duration;
11
12use gpui::BackgroundExecutor;
13use parking_lot::Mutex;
14use rand::prelude::*;
15use sea_orm::ConnectionTrait;
16use sqlx::migrate::MigrateDatabase;
17
18use self::migrations::run_database_migrations;
19
20use super::*;
21
22pub struct TestDb {
23 pub db: Option<Arc<Database>>,
24 pub connection: Option<sqlx::AnyConnection>,
25}
26
27impl TestDb {
28 pub fn sqlite(executor: BackgroundExecutor) -> Self {
29 let url = "sqlite::memory:";
30 let runtime = tokio::runtime::Builder::new_current_thread()
31 .enable_io()
32 .enable_time()
33 .build()
34 .unwrap();
35
36 let mut db = runtime.block_on(async {
37 let mut options = ConnectOptions::new(url);
38 options.max_connections(5);
39 let mut db = Database::new(options).await.unwrap();
40 let sql = include_str!(concat!(
41 env!("CARGO_MANIFEST_DIR"),
42 "/migrations.sqlite/20221109000000_test_schema.sql"
43 ));
44 db.pool
45 .execute(sea_orm::Statement::from_string(
46 db.pool.get_database_backend(),
47 sql,
48 ))
49 .await
50 .unwrap();
51 db.initialize_notification_kinds().await.unwrap();
52 db
53 });
54
55 db.test_options = Some(DatabaseTestOptions {
56 executor,
57 runtime,
58 query_failure_probability: parking_lot::Mutex::new(0.0),
59 });
60
61 Self {
62 db: Some(Arc::new(db)),
63 connection: None,
64 }
65 }
66
67 pub fn postgres(executor: BackgroundExecutor) -> Self {
68 static LOCK: Mutex<()> = Mutex::new(());
69
70 let _guard = LOCK.lock();
71 let mut rng = StdRng::from_os_rng();
72 let url = format!(
73 "postgres://postgres@localhost/zed-test-{}",
74 rng.random::<u128>()
75 );
76 let runtime = tokio::runtime::Builder::new_current_thread()
77 .enable_io()
78 .enable_time()
79 .build()
80 .unwrap();
81
82 let mut db = runtime.block_on(async {
83 sqlx::Postgres::create_database(&url)
84 .await
85 .expect("failed to create test db");
86 let mut options = ConnectOptions::new(url);
87 options
88 .max_connections(5)
89 .idle_timeout(Duration::from_secs(0));
90 let mut db = Database::new(options).await.unwrap();
91 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
92 run_database_migrations(db.options(), migrations_path)
93 .await
94 .unwrap();
95 db.initialize_notification_kinds().await.unwrap();
96 db
97 });
98
99 db.test_options = Some(DatabaseTestOptions {
100 executor,
101 runtime,
102 query_failure_probability: parking_lot::Mutex::new(0.0),
103 });
104
105 Self {
106 db: Some(Arc::new(db)),
107 connection: None,
108 }
109 }
110
111 pub fn db(&self) -> &Arc<Database> {
112 self.db.as_ref().unwrap()
113 }
114
115 pub fn set_query_failure_probability(&self, probability: f64) {
116 let database = self.db.as_ref().unwrap();
117 let test_options = database.test_options.as_ref().unwrap();
118 *test_options.query_failure_probability.lock() = probability;
119 }
120}
121
/// Generates two `#[gpui::test]` async tests that run `$test_name` against a fresh
/// [`TestDb`]: one backed by SQLite, and one backed by Postgres.
///
/// The Postgres variant is only compiled when `target_os = "macos"`.
///
/// * `$test_name` - async function called with the `&Arc<Database>` of the test database.
/// * `$postgres_test_name` - name given to the generated Postgres test.
/// * `$sqlite_test_name` - name given to the generated SQLite test.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[gpui::test]
        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
            let sqlite_db = $crate::db::TestDb::sqlite(cx.executor().clone());
            $test_name(sqlite_db.db()).await;
        }

        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            let postgres_db = $crate::db::TestDb::postgres(cx.executor().clone());
            $test_name(postgres_db.db()).await;
        }
    };
}
139
140impl Drop for TestDb {
141 fn drop(&mut self) {
142 let db = self.db.take().unwrap();
143 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
144 db.test_options.as_ref().unwrap().runtime.block_on(async {
145 use util::ResultExt;
146 let query = "
147 SELECT pg_terminate_backend(pg_stat_activity.pid)
148 FROM pg_stat_activity
149 WHERE
150 pg_stat_activity.datname = current_database() AND
151 pid <> pg_backend_pid();
152 ";
153 db.pool
154 .execute(sea_orm::Statement::from_string(
155 db.pool.get_database_backend(),
156 query,
157 ))
158 .await
159 .log_err();
160 sqlx::Postgres::drop_database(db.options.get_url())
161 .await
162 .log_err();
163 })
164 }
165 }
166}
167
168#[track_caller]
169fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
170 let expected_channels = expected.into_iter().collect::<HashSet<_>>();
171 let actual_channels = actual.into_iter().collect::<HashSet<_>>();
172 pretty_assertions::assert_eq!(expected_channels, actual_channels);
173}
174
175fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
176 use std::collections::HashMap;
177
178 let mut result = Vec::new();
179 let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();
180
181 for (id, parent_path, name) in channels {
182 let parent_key = parent_path.to_vec();
183 let order = if parent_key.is_empty() {
184 1
185 } else {
186 *order_by_parent
187 .entry(parent_key.clone())
188 .and_modify(|e| *e += 1)
189 .or_insert(1)
190 };
191
192 result.push(Channel {
193 id: *id,
194 name: (*name).to_owned(),
195 visibility: ChannelVisibility::Members,
196 parent_path: parent_key,
197 channel_order: order,
198 });
199 }
200
201 result
202}
203
204static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
205
206async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
207 db.create_user(
208 email,
209 None,
210 false,
211 NewUserParams {
212 github_login: email[0..email.find('@').unwrap()].to_string(),
213 github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
214 },
215 )
216 .await
217 .unwrap()
218 .user_id
219}