1mod buffer_tests;
2mod channel_tests;
3mod contributor_tests;
4mod db_tests;
5mod extension_tests;
6
7use crate::migrations::run_database_migrations;
8
9use super::*;
10use gpui::BackgroundExecutor;
11use parking_lot::Mutex;
12use rand::prelude::*;
13use sea_orm::ConnectionTrait;
14use sqlx::migrate::MigrateDatabase;
15use std::{
16 sync::{
17 Arc,
18 atomic::{AtomicI32, Ordering::SeqCst},
19 },
20 time::Duration,
21};
22
/// Test fixture owning a database built by [`TestDb::sqlite`] or [`TestDb::postgres`].
///
/// When a Postgres-backed `TestDb` is dropped, its database is dropped as well.
pub struct TestDb {
    /// The database under test. Set to `Some` by the constructors and taken in `Drop`.
    pub db: Option<Arc<Database>>,
    /// Optional raw connection; the constructors in this file always leave it `None`.
    pub connection: Option<sqlx::AnyConnection>,
}
27
28impl TestDb {
29 pub fn sqlite(executor: BackgroundExecutor) -> Self {
30 let url = "sqlite::memory:";
31 let runtime = tokio::runtime::Builder::new_current_thread()
32 .enable_io()
33 .enable_time()
34 .build()
35 .unwrap();
36
37 let mut db = runtime.block_on(async {
38 let mut options = ConnectOptions::new(url);
39 options.max_connections(5);
40 let mut db = Database::new(options).await.unwrap();
41 let sql = include_str!(concat!(
42 env!("CARGO_MANIFEST_DIR"),
43 "/migrations.sqlite/20221109000000_test_schema.sql"
44 ));
45 db.pool
46 .execute(sea_orm::Statement::from_string(
47 db.pool.get_database_backend(),
48 sql,
49 ))
50 .await
51 .unwrap();
52 db.initialize_notification_kinds().await.unwrap();
53 db
54 });
55
56 db.test_options = Some(DatabaseTestOptions {
57 executor,
58 runtime,
59 query_failure_probability: parking_lot::Mutex::new(0.0),
60 });
61
62 Self {
63 db: Some(Arc::new(db)),
64 connection: None,
65 }
66 }
67
68 pub fn postgres(executor: BackgroundExecutor) -> Self {
69 static LOCK: Mutex<()> = Mutex::new(());
70
71 let _guard = LOCK.lock();
72 let mut rng = StdRng::from_os_rng();
73 let url = format!(
74 "postgres://postgres@localhost/zed-test-{}",
75 rng.random::<u128>()
76 );
77 let runtime = tokio::runtime::Builder::new_current_thread()
78 .enable_io()
79 .enable_time()
80 .build()
81 .unwrap();
82
83 let mut db = runtime.block_on(async {
84 sqlx::Postgres::create_database(&url)
85 .await
86 .expect("failed to create test db");
87 let mut options = ConnectOptions::new(url);
88 options
89 .max_connections(5)
90 .idle_timeout(Duration::from_secs(0));
91 let mut db = Database::new(options).await.unwrap();
92 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
93 run_database_migrations(db.options(), migrations_path)
94 .await
95 .unwrap();
96 db.initialize_notification_kinds().await.unwrap();
97 db
98 });
99
100 db.test_options = Some(DatabaseTestOptions {
101 executor,
102 runtime,
103 query_failure_probability: parking_lot::Mutex::new(0.0),
104 });
105
106 Self {
107 db: Some(Arc::new(db)),
108 connection: None,
109 }
110 }
111
112 pub fn db(&self) -> &Arc<Database> {
113 self.db.as_ref().unwrap()
114 }
115
116 pub fn set_query_failure_probability(&self, probability: f64) {
117 let database = self.db.as_ref().unwrap();
118 let test_options = database.test_options.as_ref().unwrap();
119 *test_options.query_failure_probability.lock() = probability;
120 }
121}
122
/// Expands to two `#[gpui::test]` functions that run the async test body `$body`
/// against each database backend:
///
/// - `$postgres_name` builds a `TestDb::postgres` fixture and is compiled only on macOS;
/// - `$sqlite_name` builds a `TestDb::sqlite` fixture.
///
/// `$body` is awaited with the fixture's `&Arc<Database>`.
#[macro_export]
macro_rules! test_both_dbs {
    ($body:ident, $postgres_name:ident, $sqlite_name:ident) => {
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let fixture = $crate::db::TestDb::postgres(executor);
            $body(fixture.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let fixture = $crate::db::TestDb::sqlite(executor);
            $body(fixture.db()).await;
        }
    };
}
140
141impl Drop for TestDb {
142 fn drop(&mut self) {
143 let db = self.db.take().unwrap();
144 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
145 db.test_options.as_ref().unwrap().runtime.block_on(async {
146 use util::ResultExt;
147 let query = "
148 SELECT pg_terminate_backend(pg_stat_activity.pid)
149 FROM pg_stat_activity
150 WHERE
151 pg_stat_activity.datname = current_database() AND
152 pid <> pg_backend_pid();
153 ";
154 db.pool
155 .execute(sea_orm::Statement::from_string(
156 db.pool.get_database_backend(),
157 query,
158 ))
159 .await
160 .log_err();
161 sqlx::Postgres::drop_database(db.options.get_url())
162 .await
163 .log_err();
164 })
165 }
166 }
167}
168
169#[track_caller]
170fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
171 let expected_channels = expected.into_iter().collect::<HashSet<_>>();
172 let actual_channels = actual.into_iter().collect::<HashSet<_>>();
173 pretty_assertions::assert_eq!(expected_channels, actual_channels);
174}
175
176fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
177 use std::collections::HashMap;
178
179 let mut result = Vec::new();
180 let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();
181
182 for (id, parent_path, name) in channels {
183 let parent_key = parent_path.to_vec();
184 let order = if parent_key.is_empty() {
185 1
186 } else {
187 *order_by_parent
188 .entry(parent_key.clone())
189 .and_modify(|e| *e += 1)
190 .or_insert(1)
191 };
192
193 result.push(Channel {
194 id: *id,
195 name: (*name).to_owned(),
196 visibility: ChannelVisibility::Members,
197 parent_path: parent_key,
198 channel_order: order,
199 });
200 }
201
202 result
203}
204
/// Process-wide counter that `new_test_user` draws GitHub user ids from; starts at 5
/// and is incremented (SeqCst) on each draw.
static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
206
207async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
208 db.create_user(
209 email,
210 None,
211 false,
212 NewUserParams {
213 github_login: email[0..email.find('@').unwrap()].to_string(),
214 github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
215 },
216 )
217 .await
218 .unwrap()
219 .user_id
220}