1mod buffer_tests;
2mod channel_tests;
3mod contributor_tests;
4mod db_tests;
// We only run Postgres tests on macOS right now.
6#[cfg(target_os = "macos")]
7mod embedding_tests;
8mod extension_tests;
9mod feature_flag_tests;
10mod user_tests;
11
12use crate::migrations::run_database_migrations;
13
14use super::*;
15use gpui::BackgroundExecutor;
16use parking_lot::Mutex;
17use rand::prelude::*;
18use sea_orm::ConnectionTrait;
19use sqlx::migrate::MigrateDatabase;
20use std::{
21 sync::{
22 Arc,
23 atomic::{AtomicI32, Ordering::SeqCst},
24 },
25 time::Duration,
26};
27
/// A throwaway database used by the tests in this module.
///
/// Construct one with `TestDb::sqlite` or `TestDb::postgres`. When a
/// Postgres-backed `TestDb` is dropped, its database is dropped as well.
pub struct TestDb {
    /// The database under test. `Some` after construction; taken out in `Drop`.
    pub db: Option<Arc<Database>>,
    /// NOTE(review): the constructors in this file only ever set this to `None`.
    pub connection: Option<sqlx::AnyConnection>,
}
32
33impl TestDb {
34 pub fn sqlite(executor: BackgroundExecutor) -> Self {
35 let url = "sqlite::memory:";
36 let runtime = tokio::runtime::Builder::new_current_thread()
37 .enable_io()
38 .enable_time()
39 .build()
40 .unwrap();
41
42 let mut db = runtime.block_on(async {
43 let mut options = ConnectOptions::new(url);
44 options.max_connections(5);
45 let mut db = Database::new(options).await.unwrap();
46 let sql = include_str!(concat!(
47 env!("CARGO_MANIFEST_DIR"),
48 "/migrations.sqlite/20221109000000_test_schema.sql"
49 ));
50 db.pool
51 .execute(sea_orm::Statement::from_string(
52 db.pool.get_database_backend(),
53 sql,
54 ))
55 .await
56 .unwrap();
57 db.initialize_notification_kinds().await.unwrap();
58 db
59 });
60
61 db.test_options = Some(DatabaseTestOptions {
62 executor,
63 runtime,
64 query_failure_probability: parking_lot::Mutex::new(0.0),
65 });
66
67 Self {
68 db: Some(Arc::new(db)),
69 connection: None,
70 }
71 }
72
73 pub fn postgres(executor: BackgroundExecutor) -> Self {
74 static LOCK: Mutex<()> = Mutex::new(());
75
76 let _guard = LOCK.lock();
77 let mut rng = StdRng::from_os_rng();
78 let url = format!(
79 "postgres://postgres@localhost/zed-test-{}",
80 rng.random::<u128>()
81 );
82 let runtime = tokio::runtime::Builder::new_current_thread()
83 .enable_io()
84 .enable_time()
85 .build()
86 .unwrap();
87
88 let mut db = runtime.block_on(async {
89 sqlx::Postgres::create_database(&url)
90 .await
91 .expect("failed to create test db");
92 let mut options = ConnectOptions::new(url);
93 options
94 .max_connections(5)
95 .idle_timeout(Duration::from_secs(0));
96 let mut db = Database::new(options).await.unwrap();
97 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
98 run_database_migrations(db.options(), migrations_path)
99 .await
100 .unwrap();
101 db.initialize_notification_kinds().await.unwrap();
102 db
103 });
104
105 db.test_options = Some(DatabaseTestOptions {
106 executor,
107 runtime,
108 query_failure_probability: parking_lot::Mutex::new(0.0),
109 });
110
111 Self {
112 db: Some(Arc::new(db)),
113 connection: None,
114 }
115 }
116
117 pub fn db(&self) -> &Arc<Database> {
118 self.db.as_ref().unwrap()
119 }
120
121 pub fn set_query_failure_probability(&self, probability: f64) {
122 let database = self.db.as_ref().unwrap();
123 let test_options = database.test_options.as_ref().unwrap();
124 *test_options.query_failure_probability.lock() = probability;
125 }
126}
127
/// Generates two `gpui::test` functions that run the same async test body
/// against each database backend.
///
/// `$body` names a function that takes `&Arc<Database>` and returns a future.
/// `$pg_name` is compiled only on macOS and runs `$body` against a fresh
/// Postgres `TestDb`. `$sqlite_name` runs it against a fresh in-memory
/// SQLite `TestDb`.
#[macro_export]
macro_rules! test_both_dbs {
    ($body:ident, $pg_name:ident, $sqlite_name:ident) => {
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $pg_name(cx: &mut gpui::TestAppContext) {
            let fixture = $crate::db::TestDb::postgres(cx.executor().clone());
            $body(fixture.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_name(cx: &mut gpui::TestAppContext) {
            let fixture = $crate::db::TestDb::sqlite(cx.executor().clone());
            $body(fixture.db()).await;
        }
    };
}
145
146impl Drop for TestDb {
147 fn drop(&mut self) {
148 let db = self.db.take().unwrap();
149 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
150 db.test_options.as_ref().unwrap().runtime.block_on(async {
151 use util::ResultExt;
152 let query = "
153 SELECT pg_terminate_backend(pg_stat_activity.pid)
154 FROM pg_stat_activity
155 WHERE
156 pg_stat_activity.datname = current_database() AND
157 pid <> pg_backend_pid();
158 ";
159 db.pool
160 .execute(sea_orm::Statement::from_string(
161 db.pool.get_database_backend(),
162 query,
163 ))
164 .await
165 .log_err();
166 sqlx::Postgres::drop_database(db.options.get_url())
167 .await
168 .log_err();
169 })
170 }
171 }
172}
173
174#[track_caller]
175fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
176 let expected_channels = expected.into_iter().collect::<HashSet<_>>();
177 let actual_channels = actual.into_iter().collect::<HashSet<_>>();
178 pretty_assertions::assert_eq!(expected_channels, actual_channels);
179}
180
181fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
182 use std::collections::HashMap;
183
184 let mut result = Vec::new();
185 let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();
186
187 for (id, parent_path, name) in channels {
188 let parent_key = parent_path.to_vec();
189 let order = if parent_key.is_empty() {
190 1
191 } else {
192 *order_by_parent
193 .entry(parent_key.clone())
194 .and_modify(|e| *e += 1)
195 .or_insert(1)
196 };
197
198 result.push(Channel {
199 id: *id,
200 name: name.to_string(),
201 visibility: ChannelVisibility::Members,
202 parent_path: parent_key,
203 channel_order: order,
204 });
205 }
206
207 result
208}
209
210static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
211
212async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
213 db.create_user(
214 email,
215 None,
216 false,
217 NewUserParams {
218 github_login: email[0..email.find('@').unwrap()].to_string(),
219 github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
220 },
221 )
222 .await
223 .unwrap()
224 .user_id
225}