1mod billing_subscription_tests;
2mod buffer_tests;
3mod channel_tests;
4mod contributor_tests;
5mod db_tests;
// Postgres-backed tests currently run only on macOS.
7#[cfg(target_os = "macos")]
8mod embedding_tests;
9mod extension_tests;
10mod feature_flag_tests;
11mod message_tests;
12mod processed_stripe_event_tests;
13mod user_tests;
14
15use crate::migrations::run_database_migrations;
16
17use super::*;
18use gpui::BackgroundExecutor;
19use parking_lot::Mutex;
20use rand::prelude::*;
21use sea_orm::ConnectionTrait;
22use sqlx::migrate::MigrateDatabase;
23use std::{
24 sync::{
25 Arc,
26 atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
27 },
28 time::Duration,
29};
30
/// A database handle for tests, backed either by in-memory SQLite
/// (`TestDb::sqlite`) or by a freshly created Postgres database (`TestDb::postgres`).
pub struct TestDb {
    /// The database under test. Held in an `Option` so that `Drop` can take
    /// ownership of it during teardown.
    pub db: Option<Arc<Database>>,
    /// Always `None` when built by the constructors in this file.
    // NOTE(review): nothing visible here reads or sets this to `Some` —
    // confirm whether it is still used elsewhere.
    pub connection: Option<sqlx::AnyConnection>,
}
35
36impl TestDb {
37 pub fn sqlite(executor: BackgroundExecutor) -> Self {
38 let url = "sqlite::memory:";
39 let runtime = tokio::runtime::Builder::new_current_thread()
40 .enable_io()
41 .enable_time()
42 .build()
43 .unwrap();
44
45 let mut db = runtime.block_on(async {
46 let mut options = ConnectOptions::new(url);
47 options.max_connections(5);
48 let mut db = Database::new(options).await.unwrap();
49 let sql = include_str!(concat!(
50 env!("CARGO_MANIFEST_DIR"),
51 "/migrations.sqlite/20221109000000_test_schema.sql"
52 ));
53 db.pool
54 .execute(sea_orm::Statement::from_string(
55 db.pool.get_database_backend(),
56 sql,
57 ))
58 .await
59 .unwrap();
60 db.initialize_notification_kinds().await.unwrap();
61 db
62 });
63
64 db.test_options = Some(DatabaseTestOptions {
65 executor,
66 runtime,
67 query_failure_probability: parking_lot::Mutex::new(0.0),
68 });
69
70 Self {
71 db: Some(Arc::new(db)),
72 connection: None,
73 }
74 }
75
76 pub fn postgres(executor: BackgroundExecutor) -> Self {
77 static LOCK: Mutex<()> = Mutex::new(());
78
79 let _guard = LOCK.lock();
80 let mut rng = StdRng::from_entropy();
81 let url = format!(
82 "postgres://postgres@localhost/zed-test-{}",
83 rng.r#gen::<u128>()
84 );
85 let runtime = tokio::runtime::Builder::new_current_thread()
86 .enable_io()
87 .enable_time()
88 .build()
89 .unwrap();
90
91 let mut db = runtime.block_on(async {
92 sqlx::Postgres::create_database(&url)
93 .await
94 .expect("failed to create test db");
95 let mut options = ConnectOptions::new(url);
96 options
97 .max_connections(5)
98 .idle_timeout(Duration::from_secs(0));
99 let mut db = Database::new(options).await.unwrap();
100 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
101 run_database_migrations(db.options(), migrations_path)
102 .await
103 .unwrap();
104 db.initialize_notification_kinds().await.unwrap();
105 db
106 });
107
108 db.test_options = Some(DatabaseTestOptions {
109 executor,
110 runtime,
111 query_failure_probability: parking_lot::Mutex::new(0.0),
112 });
113
114 Self {
115 db: Some(Arc::new(db)),
116 connection: None,
117 }
118 }
119
120 pub fn db(&self) -> &Arc<Database> {
121 self.db.as_ref().unwrap()
122 }
123
124 pub fn set_query_failure_probability(&self, probability: f64) {
125 let database = self.db.as_ref().unwrap();
126 let test_options = database.test_options.as_ref().unwrap();
127 *test_options.query_failure_probability.lock() = probability;
128 }
129}
130
/// Generates one test per database backend from a single async test body.
///
/// `$test_name` is called with the `&Arc<Database>` returned by `TestDb::db`
/// and awaited. Two `#[gpui::test]` functions are emitted:
/// `$sqlite_test_name`, compiled on every platform, and `$postgres_test_name`,
/// compiled only on macOS.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[gpui::test]
        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
            let sqlite_handle = $crate::db::TestDb::sqlite(cx.executor().clone());
            $test_name(sqlite_handle.db()).await;
        }

        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            let postgres_handle = $crate::db::TestDb::postgres(cx.executor().clone());
            $test_name(postgres_handle.db()).await;
        }
    };
}
148
149impl Drop for TestDb {
150 fn drop(&mut self) {
151 let db = self.db.take().unwrap();
152 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
153 db.test_options.as_ref().unwrap().runtime.block_on(async {
154 use util::ResultExt;
155 let query = "
156 SELECT pg_terminate_backend(pg_stat_activity.pid)
157 FROM pg_stat_activity
158 WHERE
159 pg_stat_activity.datname = current_database() AND
160 pid <> pg_backend_pid();
161 ";
162 db.pool
163 .execute(sea_orm::Statement::from_string(
164 db.pool.get_database_backend(),
165 query,
166 ))
167 .await
168 .log_err();
169 sqlx::Postgres::drop_database(db.options.get_url())
170 .await
171 .log_err();
172 })
173 }
174 }
175}
176
177#[track_caller]
178fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
179 let expected_channels = expected.into_iter().collect::<HashSet<_>>();
180 let actual_channels = actual.into_iter().collect::<HashSet<_>>();
181 pretty_assertions::assert_eq!(expected_channels, actual_channels);
182}
183
184fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
185 use std::collections::HashMap;
186
187 let mut result = Vec::new();
188 let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();
189
190 for (id, parent_path, name) in channels {
191 let parent_key = parent_path.to_vec();
192 let order = if parent_key.is_empty() {
193 1
194 } else {
195 *order_by_parent
196 .entry(parent_key.clone())
197 .and_modify(|e| *e += 1)
198 .or_insert(1)
199 };
200
201 result.push(Channel {
202 id: *id,
203 name: name.to_string(),
204 visibility: ChannelVisibility::Members,
205 parent_path: parent_key,
206 channel_order: order,
207 });
208 }
209
210 result
211}
212
213static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
214
215async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
216 db.create_user(
217 email,
218 None,
219 false,
220 NewUserParams {
221 github_login: email[0..email.find('@').unwrap()].to_string(),
222 github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
223 },
224 )
225 .await
226 .unwrap()
227 .user_id
228}
229
230static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
231fn new_test_connection(server: ServerId) -> ConnectionId {
232 ConnectionId {
233 id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
234 owner_id: server.0 as u32,
235 }
236}