1mod buffer_tests;
2mod channel_tests;
3mod contributor_tests;
4mod db_tests;
5// we only run postgres tests on macos right now
6#[cfg(target_os = "macos")]
7mod embedding_tests;
8mod extension_tests;
9
10use crate::migrations::run_database_migrations;
11
12use super::*;
13use gpui::BackgroundExecutor;
14use parking_lot::Mutex;
15use rand::prelude::*;
16use sea_orm::ConnectionTrait;
17use sqlx::migrate::MigrateDatabase;
18use std::{
19 sync::{
20 Arc,
21 atomic::{AtomicI32, Ordering::SeqCst},
22 },
23 time::Duration,
24};
25
26pub struct TestDb {
27 pub db: Option<Arc<Database>>,
28 pub connection: Option<sqlx::AnyConnection>,
29}
30
31impl TestDb {
32 pub fn sqlite(executor: BackgroundExecutor) -> Self {
33 let url = "sqlite::memory:";
34 let runtime = tokio::runtime::Builder::new_current_thread()
35 .enable_io()
36 .enable_time()
37 .build()
38 .unwrap();
39
40 let mut db = runtime.block_on(async {
41 let mut options = ConnectOptions::new(url);
42 options.max_connections(5);
43 let mut db = Database::new(options).await.unwrap();
44 let sql = include_str!(concat!(
45 env!("CARGO_MANIFEST_DIR"),
46 "/migrations.sqlite/20221109000000_test_schema.sql"
47 ));
48 db.pool
49 .execute(sea_orm::Statement::from_string(
50 db.pool.get_database_backend(),
51 sql,
52 ))
53 .await
54 .unwrap();
55 db.initialize_notification_kinds().await.unwrap();
56 db
57 });
58
59 db.test_options = Some(DatabaseTestOptions {
60 executor,
61 runtime,
62 query_failure_probability: parking_lot::Mutex::new(0.0),
63 });
64
65 Self {
66 db: Some(Arc::new(db)),
67 connection: None,
68 }
69 }
70
71 pub fn postgres(executor: BackgroundExecutor) -> Self {
72 static LOCK: Mutex<()> = Mutex::new(());
73
74 let _guard = LOCK.lock();
75 let mut rng = StdRng::from_os_rng();
76 let url = format!(
77 "postgres://postgres@localhost/zed-test-{}",
78 rng.random::<u128>()
79 );
80 let runtime = tokio::runtime::Builder::new_current_thread()
81 .enable_io()
82 .enable_time()
83 .build()
84 .unwrap();
85
86 let mut db = runtime.block_on(async {
87 sqlx::Postgres::create_database(&url)
88 .await
89 .expect("failed to create test db");
90 let mut options = ConnectOptions::new(url);
91 options
92 .max_connections(5)
93 .idle_timeout(Duration::from_secs(0));
94 let mut db = Database::new(options).await.unwrap();
95 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
96 run_database_migrations(db.options(), migrations_path)
97 .await
98 .unwrap();
99 db.initialize_notification_kinds().await.unwrap();
100 db
101 });
102
103 db.test_options = Some(DatabaseTestOptions {
104 executor,
105 runtime,
106 query_failure_probability: parking_lot::Mutex::new(0.0),
107 });
108
109 Self {
110 db: Some(Arc::new(db)),
111 connection: None,
112 }
113 }
114
115 pub fn db(&self) -> &Arc<Database> {
116 self.db.as_ref().unwrap()
117 }
118
119 pub fn set_query_failure_probability(&self, probability: f64) {
120 let database = self.db.as_ref().unwrap();
121 let test_options = database.test_options.as_ref().unwrap();
122 *test_options.query_failure_probability.lock() = probability;
123 }
124}
125
/// Defines two `#[gpui::test]` functions that run one async test function
/// against both database backends.
///
/// Arguments, in order:
/// 1. the async test function to run; it is called with the value returned
///    by `TestDb::db()` and its result is awaited,
/// 2. the name of the generated Postgres test (only compiled on macOS),
/// 3. the name of the generated SQLite test.
#[macro_export]
macro_rules! test_both_dbs {
    ($body:ident, $postgres_fn:ident, $sqlite_fn:ident) => {
        #[gpui::test]
        async fn $sqlite_fn(cx: &mut gpui::TestAppContext) {
            // The handle is kept in a local so the database outlives the awaited test.
            let handle = $crate::db::TestDb::sqlite(cx.executor().clone());
            $body(handle.db()).await;
        }

        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_fn(cx: &mut gpui::TestAppContext) {
            // The handle is kept in a local so the database outlives the awaited test.
            let handle = $crate::db::TestDb::postgres(cx.executor().clone());
            $body(handle.db()).await;
        }
    };
}
143
144impl Drop for TestDb {
145 fn drop(&mut self) {
146 let db = self.db.take().unwrap();
147 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
148 db.test_options.as_ref().unwrap().runtime.block_on(async {
149 use util::ResultExt;
150 let query = "
151 SELECT pg_terminate_backend(pg_stat_activity.pid)
152 FROM pg_stat_activity
153 WHERE
154 pg_stat_activity.datname = current_database() AND
155 pid <> pg_backend_pid();
156 ";
157 db.pool
158 .execute(sea_orm::Statement::from_string(
159 db.pool.get_database_backend(),
160 query,
161 ))
162 .await
163 .log_err();
164 sqlx::Postgres::drop_database(db.options.get_url())
165 .await
166 .log_err();
167 })
168 }
169 }
170}
171
172#[track_caller]
173fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
174 let expected_channels = expected.into_iter().collect::<HashSet<_>>();
175 let actual_channels = actual.into_iter().collect::<HashSet<_>>();
176 pretty_assertions::assert_eq!(expected_channels, actual_channels);
177}
178
179fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
180 use std::collections::HashMap;
181
182 let mut result = Vec::new();
183 let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();
184
185 for (id, parent_path, name) in channels {
186 let parent_key = parent_path.to_vec();
187 let order = if parent_key.is_empty() {
188 1
189 } else {
190 *order_by_parent
191 .entry(parent_key.clone())
192 .and_modify(|e| *e += 1)
193 .or_insert(1)
194 };
195
196 result.push(Channel {
197 id: *id,
198 name: (*name).to_owned(),
199 visibility: ChannelVisibility::Members,
200 parent_path: parent_key,
201 channel_order: order,
202 });
203 }
204
205 result
206}
207
208static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
209
210async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
211 db.create_user(
212 email,
213 None,
214 false,
215 NewUserParams {
216 github_login: email[0..email.find('@').unwrap()].to_string(),
217 github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
218 },
219 )
220 .await
221 .unwrap()
222 .user_id
223}