1mod buffer_tests;
2mod channel_tests;
3mod contributor_tests;
4mod db_tests;
5// we only run postgres tests on macos right now
6#[cfg(target_os = "macos")]
7mod embedding_tests;
8mod extension_tests;
9mod feature_flag_tests;
10mod message_tests;
11mod processed_stripe_event_tests;
12mod user_tests;
13
14use crate::migrations::run_database_migrations;
15
16use super::*;
17use gpui::BackgroundExecutor;
18use parking_lot::Mutex;
19use rand::prelude::*;
20use sea_orm::ConnectionTrait;
21use sqlx::migrate::MigrateDatabase;
22use std::{
23 sync::{
24 Arc,
25 atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
26 },
27 time::Duration,
28};
29
30pub struct TestDb {
31 pub db: Option<Arc<Database>>,
32 pub connection: Option<sqlx::AnyConnection>,
33}
34
35impl TestDb {
36 pub fn sqlite(executor: BackgroundExecutor) -> Self {
37 let url = "sqlite::memory:";
38 let runtime = tokio::runtime::Builder::new_current_thread()
39 .enable_io()
40 .enable_time()
41 .build()
42 .unwrap();
43
44 let mut db = runtime.block_on(async {
45 let mut options = ConnectOptions::new(url);
46 options.max_connections(5);
47 let mut db = Database::new(options).await.unwrap();
48 let sql = include_str!(concat!(
49 env!("CARGO_MANIFEST_DIR"),
50 "/migrations.sqlite/20221109000000_test_schema.sql"
51 ));
52 db.pool
53 .execute(sea_orm::Statement::from_string(
54 db.pool.get_database_backend(),
55 sql,
56 ))
57 .await
58 .unwrap();
59 db.initialize_notification_kinds().await.unwrap();
60 db
61 });
62
63 db.test_options = Some(DatabaseTestOptions {
64 executor,
65 runtime,
66 query_failure_probability: parking_lot::Mutex::new(0.0),
67 });
68
69 Self {
70 db: Some(Arc::new(db)),
71 connection: None,
72 }
73 }
74
75 pub fn postgres(executor: BackgroundExecutor) -> Self {
76 static LOCK: Mutex<()> = Mutex::new(());
77
78 let _guard = LOCK.lock();
79 let mut rng = StdRng::from_entropy();
80 let url = format!(
81 "postgres://postgres@localhost/zed-test-{}",
82 rng.r#gen::<u128>()
83 );
84 let runtime = tokio::runtime::Builder::new_current_thread()
85 .enable_io()
86 .enable_time()
87 .build()
88 .unwrap();
89
90 let mut db = runtime.block_on(async {
91 sqlx::Postgres::create_database(&url)
92 .await
93 .expect("failed to create test db");
94 let mut options = ConnectOptions::new(url);
95 options
96 .max_connections(5)
97 .idle_timeout(Duration::from_secs(0));
98 let mut db = Database::new(options).await.unwrap();
99 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
100 run_database_migrations(db.options(), migrations_path)
101 .await
102 .unwrap();
103 db.initialize_notification_kinds().await.unwrap();
104 db
105 });
106
107 db.test_options = Some(DatabaseTestOptions {
108 executor,
109 runtime,
110 query_failure_probability: parking_lot::Mutex::new(0.0),
111 });
112
113 Self {
114 db: Some(Arc::new(db)),
115 connection: None,
116 }
117 }
118
119 pub fn db(&self) -> &Arc<Database> {
120 self.db.as_ref().unwrap()
121 }
122
123 pub fn set_query_failure_probability(&self, probability: f64) {
124 let database = self.db.as_ref().unwrap();
125 let test_options = database.test_options.as_ref().unwrap();
126 *test_options.query_failure_probability.lock() = probability;
127 }
128}
129
/// Generates a Postgres variant and a SQLite variant of the async test `$test_name`.
///
/// Each generated `#[gpui::test]` builds its own fresh `TestDb`. It then calls
/// `$test_name` with the fixture's `&Arc<Database>` and awaits the result.
/// `$postgres_test_name` is compiled only on macOS, while
/// `$sqlite_test_name` is compiled on every platform.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let fixture = $crate::db::TestDb::postgres(executor);
            $test_name(fixture.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let fixture = $crate::db::TestDb::sqlite(executor);
            $test_name(fixture.db()).await;
        }
    };
}
147
148impl Drop for TestDb {
149 fn drop(&mut self) {
150 let db = self.db.take().unwrap();
151 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
152 db.test_options.as_ref().unwrap().runtime.block_on(async {
153 use util::ResultExt;
154 let query = "
155 SELECT pg_terminate_backend(pg_stat_activity.pid)
156 FROM pg_stat_activity
157 WHERE
158 pg_stat_activity.datname = current_database() AND
159 pid <> pg_backend_pid();
160 ";
161 db.pool
162 .execute(sea_orm::Statement::from_string(
163 db.pool.get_database_backend(),
164 query,
165 ))
166 .await
167 .log_err();
168 sqlx::Postgres::drop_database(db.options.get_url())
169 .await
170 .log_err();
171 })
172 }
173 }
174}
175
176#[track_caller]
177fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
178 let expected_channels = expected.into_iter().collect::<HashSet<_>>();
179 let actual_channels = actual.into_iter().collect::<HashSet<_>>();
180 pretty_assertions::assert_eq!(expected_channels, actual_channels);
181}
182
183fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
184 use std::collections::HashMap;
185
186 let mut result = Vec::new();
187 let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();
188
189 for (id, parent_path, name) in channels {
190 let parent_key = parent_path.to_vec();
191 let order = if parent_key.is_empty() {
192 1
193 } else {
194 *order_by_parent
195 .entry(parent_key.clone())
196 .and_modify(|e| *e += 1)
197 .or_insert(1)
198 };
199
200 result.push(Channel {
201 id: *id,
202 name: name.to_string(),
203 visibility: ChannelVisibility::Members,
204 parent_path: parent_key,
205 channel_order: order,
206 });
207 }
208
209 result
210}
211
212static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
213
214async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
215 db.create_user(
216 email,
217 None,
218 false,
219 NewUserParams {
220 github_login: email[0..email.find('@').unwrap()].to_string(),
221 github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
222 },
223 )
224 .await
225 .unwrap()
226 .user_id
227}
228
229static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
230fn new_test_connection(server: ServerId) -> ConnectionId {
231 ConnectionId {
232 id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
233 owner_id: server.0 as u32,
234 }
235}