1mod billing_subscription_tests;
2mod buffer_tests;
3mod channel_tests;
4mod contributor_tests;
5mod db_tests;
6// we only run postgres tests on macos right now
7#[cfg(target_os = "macos")]
8mod embedding_tests;
9mod extension_tests;
10mod feature_flag_tests;
11mod message_tests;
12mod processed_stripe_event_tests;
13mod user_tests;
14
15use crate::migrations::run_database_migrations;
16
17use super::*;
18use gpui::BackgroundExecutor;
19use parking_lot::Mutex;
20use sea_orm::ConnectionTrait;
21use sqlx::migrate::MigrateDatabase;
22use std::sync::{
23 Arc,
24 atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
25};
26
27pub struct TestDb {
28 pub db: Option<Arc<Database>>,
29 pub connection: Option<sqlx::AnyConnection>,
30}
31
32impl TestDb {
33 pub fn sqlite(executor: BackgroundExecutor) -> Self {
34 let url = "sqlite::memory:";
35 let runtime = tokio::runtime::Builder::new_current_thread()
36 .enable_io()
37 .enable_time()
38 .build()
39 .unwrap();
40
41 let mut db = runtime.block_on(async {
42 let mut options = ConnectOptions::new(url);
43 options.max_connections(5);
44 let mut db = Database::new(options, Executor::Deterministic(executor.clone()))
45 .await
46 .unwrap();
47 let sql = include_str!(concat!(
48 env!("CARGO_MANIFEST_DIR"),
49 "/migrations.sqlite/20221109000000_test_schema.sql"
50 ));
51 db.pool
52 .execute(sea_orm::Statement::from_string(
53 db.pool.get_database_backend(),
54 sql,
55 ))
56 .await
57 .unwrap();
58 db.initialize_notification_kinds().await.unwrap();
59 db
60 });
61
62 db.test_options = Some(DatabaseTestOptions {
63 runtime,
64 query_failure_probability: parking_lot::Mutex::new(0.0),
65 });
66
67 Self {
68 db: Some(Arc::new(db)),
69 connection: None,
70 }
71 }
72
73 pub fn postgres(executor: BackgroundExecutor) -> Self {
74 static LOCK: Mutex<()> = Mutex::new(());
75
76 let _guard = LOCK.lock();
77 let mut rng = StdRng::from_entropy();
78 let url = format!(
79 "postgres://postgres@localhost/zed-test-{}",
80 rng.r#gen::<u128>()
81 );
82 let runtime = tokio::runtime::Builder::new_current_thread()
83 .enable_io()
84 .enable_time()
85 .build()
86 .unwrap();
87
88 let mut db = runtime.block_on(async {
89 sqlx::Postgres::create_database(&url)
90 .await
91 .expect("failed to create test db");
92 let mut options = ConnectOptions::new(url);
93 options
94 .max_connections(5)
95 .idle_timeout(Duration::from_secs(0));
96 let mut db = Database::new(options, Executor::Deterministic(executor.clone()))
97 .await
98 .unwrap();
99 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
100 run_database_migrations(db.options(), migrations_path)
101 .await
102 .unwrap();
103 db.initialize_notification_kinds().await.unwrap();
104 db
105 });
106
107 db.test_options = Some(DatabaseTestOptions {
108 runtime,
109 query_failure_probability: parking_lot::Mutex::new(0.0),
110 });
111
112 Self {
113 db: Some(Arc::new(db)),
114 connection: None,
115 }
116 }
117
118 pub fn db(&self) -> &Arc<Database> {
119 self.db.as_ref().unwrap()
120 }
121
122 pub fn set_query_failure_probability(&self, probability: f64) {
123 let database = self.db.as_ref().unwrap();
124 let test_options = database.test_options.as_ref().unwrap();
125 *test_options.query_failure_probability.lock() = probability;
126 }
127}
128
/// Expands to two `#[gpui::test]` functions that run the same async test
/// function against a Postgres-backed and a SQLite-backed `TestDb`.
///
/// * `$test_name` - the async function to run; it is called with the value
///   returned by `TestDb::db()` and awaited.
/// * `$postgres_test_name` - name of the generated Postgres test, which is
///   only compiled when `target_os = "macos"`.
/// * `$sqlite_test_name` - name of the generated SQLite test.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            let postgres = $crate::db::TestDb::postgres(cx.executor().clone());
            $test_name(postgres.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
            let sqlite = $crate::db::TestDb::sqlite(cx.executor().clone());
            $test_name(sqlite.db()).await;
        }
    };
}
146
147impl Drop for TestDb {
148 fn drop(&mut self) {
149 let db = self.db.take().unwrap();
150 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
151 db.test_options.as_ref().unwrap().runtime.block_on(async {
152 use util::ResultExt;
153 let query = "
154 SELECT pg_terminate_backend(pg_stat_activity.pid)
155 FROM pg_stat_activity
156 WHERE
157 pg_stat_activity.datname = current_database() AND
158 pid <> pg_backend_pid();
159 ";
160 db.pool
161 .execute(sea_orm::Statement::from_string(
162 db.pool.get_database_backend(),
163 query,
164 ))
165 .await
166 .log_err();
167 sqlx::Postgres::drop_database(db.options.get_url())
168 .await
169 .log_err();
170 })
171 }
172 }
173}
174
175#[track_caller]
176fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
177 let expected_channels = expected.into_iter().collect::<HashSet<_>>();
178 let actual_channels = actual.into_iter().collect::<HashSet<_>>();
179 pretty_assertions::assert_eq!(expected_channels, actual_channels);
180}
181
182fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
183 use std::collections::HashMap;
184
185 let mut result = Vec::new();
186 let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();
187
188 for (id, parent_path, name) in channels {
189 let parent_key = parent_path.to_vec();
190 let order = if parent_key.is_empty() {
191 1
192 } else {
193 *order_by_parent
194 .entry(parent_key.clone())
195 .and_modify(|e| *e += 1)
196 .or_insert(1)
197 };
198
199 result.push(Channel {
200 id: *id,
201 name: name.to_string(),
202 visibility: ChannelVisibility::Members,
203 parent_path: parent_key,
204 channel_order: order,
205 });
206 }
207
208 result
209}
210
211static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
212
213async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
214 db.create_user(
215 email,
216 None,
217 false,
218 NewUserParams {
219 github_login: email[0..email.find('@').unwrap()].to_string(),
220 github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
221 },
222 )
223 .await
224 .unwrap()
225 .user_id
226}
227
228static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
229fn new_test_connection(server: ServerId) -> ConnectionId {
230 ConnectionId {
231 id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
232 owner_id: server.0 as u32,
233 }
234}