1mod buffer_tests;
2mod channel_tests;
3mod contributor_tests;
4mod db_tests;
// We only run the Postgres-backed tests on macOS for now.
6#[cfg(target_os = "macos")]
7mod embedding_tests;
8mod extension_tests;
9mod user_tests;
10
11use crate::migrations::run_database_migrations;
12
13use super::*;
14use gpui::BackgroundExecutor;
15use parking_lot::Mutex;
16use rand::prelude::*;
17use sea_orm::ConnectionTrait;
18use sqlx::migrate::MigrateDatabase;
19use std::{
20 sync::{
21 Arc,
22 atomic::{AtomicI32, Ordering::SeqCst},
23 },
24 time::Duration,
25};
26
/// A database handle for tests, built by `TestDb::sqlite` or `TestDb::postgres`.
///
/// When a Postgres-backed `TestDb` is dropped, it attempts to drop the
/// underlying test database.
pub struct TestDb {
    /// The database under test. It is `Some` from construction until `Drop` takes it.
    pub db: Option<Arc<Database>>,
    /// NOTE(review): both constructors in this file set this to `None`;
    /// confirm whether anything else still populates it.
    pub connection: Option<sqlx::AnyConnection>,
}
31
32impl TestDb {
33 pub fn sqlite(executor: BackgroundExecutor) -> Self {
34 let url = "sqlite::memory:";
35 let runtime = tokio::runtime::Builder::new_current_thread()
36 .enable_io()
37 .enable_time()
38 .build()
39 .unwrap();
40
41 let mut db = runtime.block_on(async {
42 let mut options = ConnectOptions::new(url);
43 options.max_connections(5);
44 let mut db = Database::new(options).await.unwrap();
45 let sql = include_str!(concat!(
46 env!("CARGO_MANIFEST_DIR"),
47 "/migrations.sqlite/20221109000000_test_schema.sql"
48 ));
49 db.pool
50 .execute(sea_orm::Statement::from_string(
51 db.pool.get_database_backend(),
52 sql,
53 ))
54 .await
55 .unwrap();
56 db.initialize_notification_kinds().await.unwrap();
57 db
58 });
59
60 db.test_options = Some(DatabaseTestOptions {
61 executor,
62 runtime,
63 query_failure_probability: parking_lot::Mutex::new(0.0),
64 });
65
66 Self {
67 db: Some(Arc::new(db)),
68 connection: None,
69 }
70 }
71
72 pub fn postgres(executor: BackgroundExecutor) -> Self {
73 static LOCK: Mutex<()> = Mutex::new(());
74
75 let _guard = LOCK.lock();
76 let mut rng = StdRng::from_os_rng();
77 let url = format!(
78 "postgres://postgres@localhost/zed-test-{}",
79 rng.random::<u128>()
80 );
81 let runtime = tokio::runtime::Builder::new_current_thread()
82 .enable_io()
83 .enable_time()
84 .build()
85 .unwrap();
86
87 let mut db = runtime.block_on(async {
88 sqlx::Postgres::create_database(&url)
89 .await
90 .expect("failed to create test db");
91 let mut options = ConnectOptions::new(url);
92 options
93 .max_connections(5)
94 .idle_timeout(Duration::from_secs(0));
95 let mut db = Database::new(options).await.unwrap();
96 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
97 run_database_migrations(db.options(), migrations_path)
98 .await
99 .unwrap();
100 db.initialize_notification_kinds().await.unwrap();
101 db
102 });
103
104 db.test_options = Some(DatabaseTestOptions {
105 executor,
106 runtime,
107 query_failure_probability: parking_lot::Mutex::new(0.0),
108 });
109
110 Self {
111 db: Some(Arc::new(db)),
112 connection: None,
113 }
114 }
115
116 pub fn db(&self) -> &Arc<Database> {
117 self.db.as_ref().unwrap()
118 }
119
120 pub fn set_query_failure_probability(&self, probability: f64) {
121 let database = self.db.as_ref().unwrap();
122 let test_options = database.test_options.as_ref().unwrap();
123 *test_options.query_failure_probability.lock() = probability;
124 }
125}
126
/// Generates two async gpui tests that run the same test body.
///
/// `$pg_test` runs it against a Postgres `TestDb` and is only compiled on
/// macOS. `$sqlite_test` runs it against a SQLite `TestDb`. Each generated
/// test builds a fresh `TestDb` from the test context's executor and passes
/// `&Arc<Database>` to `$test_fn`.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_fn:ident, $pg_test:ident, $sqlite_test:ident) => {
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $pg_test(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let db = $crate::db::TestDb::postgres(executor);
            $test_fn(db.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_test(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let db = $crate::db::TestDb::sqlite(executor);
            $test_fn(db.db()).await;
        }
    };
}
144
145impl Drop for TestDb {
146 fn drop(&mut self) {
147 let db = self.db.take().unwrap();
148 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
149 db.test_options.as_ref().unwrap().runtime.block_on(async {
150 use util::ResultExt;
151 let query = "
152 SELECT pg_terminate_backend(pg_stat_activity.pid)
153 FROM pg_stat_activity
154 WHERE
155 pg_stat_activity.datname = current_database() AND
156 pid <> pg_backend_pid();
157 ";
158 db.pool
159 .execute(sea_orm::Statement::from_string(
160 db.pool.get_database_backend(),
161 query,
162 ))
163 .await
164 .log_err();
165 sqlx::Postgres::drop_database(db.options.get_url())
166 .await
167 .log_err();
168 })
169 }
170 }
171}
172
173#[track_caller]
174fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
175 let expected_channels = expected.into_iter().collect::<HashSet<_>>();
176 let actual_channels = actual.into_iter().collect::<HashSet<_>>();
177 pretty_assertions::assert_eq!(expected_channels, actual_channels);
178}
179
180fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
181 use std::collections::HashMap;
182
183 let mut result = Vec::new();
184 let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();
185
186 for (id, parent_path, name) in channels {
187 let parent_key = parent_path.to_vec();
188 let order = if parent_key.is_empty() {
189 1
190 } else {
191 *order_by_parent
192 .entry(parent_key.clone())
193 .and_modify(|e| *e += 1)
194 .or_insert(1)
195 };
196
197 result.push(Channel {
198 id: *id,
199 name: name.to_string(),
200 visibility: ChannelVisibility::Members,
201 parent_path: parent_key,
202 channel_order: order,
203 });
204 }
205
206 result
207}
208
209static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
210
211async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
212 db.create_user(
213 email,
214 None,
215 false,
216 NewUserParams {
217 github_login: email[0..email.find('@').unwrap()].to_string(),
218 github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
219 },
220 )
221 .await
222 .unwrap()
223 .user_id
224}