1mod buffer_tests;
2mod channel_tests;
3mod contributor_tests;
4mod db_tests;
5// we only run postgres tests on macos right now
6#[cfg(target_os = "macos")]
7mod embedding_tests;
8mod extension_tests;
9mod feature_flag_tests;
10mod message_tests;
11mod user_tests;
12
13use crate::migrations::run_database_migrations;
14
15use super::*;
16use gpui::BackgroundExecutor;
17use parking_lot::Mutex;
18use rand::prelude::*;
19use sea_orm::ConnectionTrait;
20use sqlx::migrate::MigrateDatabase;
21use std::{
22 sync::{
23 Arc,
24 atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
25 },
26 time::Duration,
27};
28
/// A database handle for tests, created by `TestDb::sqlite` or `TestDb::postgres`.
///
/// For Postgres-backed instances, dropping a `TestDb` terminates other sessions on the
/// test database and attempts to drop the database.
pub struct TestDb {
    /// The database under test. Set to `Some` by both constructors in this module and
    /// taken out again when the `TestDb` is dropped.
    pub db: Option<Arc<Database>>,
    /// Optional raw sqlx connection. The constructors in this module always set this to
    /// `None`.
    pub connection: Option<sqlx::AnyConnection>,
}
33
34impl TestDb {
35 pub fn sqlite(executor: BackgroundExecutor) -> Self {
36 let url = "sqlite::memory:";
37 let runtime = tokio::runtime::Builder::new_current_thread()
38 .enable_io()
39 .enable_time()
40 .build()
41 .unwrap();
42
43 let mut db = runtime.block_on(async {
44 let mut options = ConnectOptions::new(url);
45 options.max_connections(5);
46 let mut db = Database::new(options).await.unwrap();
47 let sql = include_str!(concat!(
48 env!("CARGO_MANIFEST_DIR"),
49 "/migrations.sqlite/20221109000000_test_schema.sql"
50 ));
51 db.pool
52 .execute(sea_orm::Statement::from_string(
53 db.pool.get_database_backend(),
54 sql,
55 ))
56 .await
57 .unwrap();
58 db.initialize_notification_kinds().await.unwrap();
59 db
60 });
61
62 db.test_options = Some(DatabaseTestOptions {
63 executor,
64 runtime,
65 query_failure_probability: parking_lot::Mutex::new(0.0),
66 });
67
68 Self {
69 db: Some(Arc::new(db)),
70 connection: None,
71 }
72 }
73
74 pub fn postgres(executor: BackgroundExecutor) -> Self {
75 static LOCK: Mutex<()> = Mutex::new(());
76
77 let _guard = LOCK.lock();
78 let mut rng = StdRng::from_os_rng();
79 let url = format!(
80 "postgres://postgres@localhost/zed-test-{}",
81 rng.random::<u128>()
82 );
83 let runtime = tokio::runtime::Builder::new_current_thread()
84 .enable_io()
85 .enable_time()
86 .build()
87 .unwrap();
88
89 let mut db = runtime.block_on(async {
90 sqlx::Postgres::create_database(&url)
91 .await
92 .expect("failed to create test db");
93 let mut options = ConnectOptions::new(url);
94 options
95 .max_connections(5)
96 .idle_timeout(Duration::from_secs(0));
97 let mut db = Database::new(options).await.unwrap();
98 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
99 run_database_migrations(db.options(), migrations_path)
100 .await
101 .unwrap();
102 db.initialize_notification_kinds().await.unwrap();
103 db
104 });
105
106 db.test_options = Some(DatabaseTestOptions {
107 executor,
108 runtime,
109 query_failure_probability: parking_lot::Mutex::new(0.0),
110 });
111
112 Self {
113 db: Some(Arc::new(db)),
114 connection: None,
115 }
116 }
117
118 pub fn db(&self) -> &Arc<Database> {
119 self.db.as_ref().unwrap()
120 }
121
122 pub fn set_query_failure_probability(&self, probability: f64) {
123 let database = self.db.as_ref().unwrap();
124 let test_options = database.test_options.as_ref().unwrap();
125 *test_options.query_failure_probability.lock() = probability;
126 }
127}
128
/// Defines a pair of `#[gpui::test]` functions that run the same database test against
/// both backends.
///
/// - `$postgres_test_name` builds a `TestDb::postgres` database and is only compiled on
///   macOS.
/// - `$sqlite_test_name` builds a `TestDb::sqlite` database.
///
/// Both variants evaluate `$test_name(test_db.db()).await`. `$test_name` must therefore be
/// callable with a `&Arc<Database>` and return a future, typically as an `async fn`.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        // Postgres tests only run on macOS.
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            // `test_db` stays alive until the end of the test, so its `Drop` teardown
            // runs after `$test_name` completes.
            let test_db = $crate::db::TestDb::postgres(cx.executor().clone());
            $test_name(test_db.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
            let test_db = $crate::db::TestDb::sqlite(cx.executor().clone());
            $test_name(test_db.db()).await;
        }
    };
}
146
147impl Drop for TestDb {
148 fn drop(&mut self) {
149 let db = self.db.take().unwrap();
150 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
151 db.test_options.as_ref().unwrap().runtime.block_on(async {
152 use util::ResultExt;
153 let query = "
154 SELECT pg_terminate_backend(pg_stat_activity.pid)
155 FROM pg_stat_activity
156 WHERE
157 pg_stat_activity.datname = current_database() AND
158 pid <> pg_backend_pid();
159 ";
160 db.pool
161 .execute(sea_orm::Statement::from_string(
162 db.pool.get_database_backend(),
163 query,
164 ))
165 .await
166 .log_err();
167 sqlx::Postgres::drop_database(db.options.get_url())
168 .await
169 .log_err();
170 })
171 }
172 }
173}
174
175#[track_caller]
176fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
177 let expected_channels = expected.into_iter().collect::<HashSet<_>>();
178 let actual_channels = actual.into_iter().collect::<HashSet<_>>();
179 pretty_assertions::assert_eq!(expected_channels, actual_channels);
180}
181
182fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
183 use std::collections::HashMap;
184
185 let mut result = Vec::new();
186 let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();
187
188 for (id, parent_path, name) in channels {
189 let parent_key = parent_path.to_vec();
190 let order = if parent_key.is_empty() {
191 1
192 } else {
193 *order_by_parent
194 .entry(parent_key.clone())
195 .and_modify(|e| *e += 1)
196 .or_insert(1)
197 };
198
199 result.push(Channel {
200 id: *id,
201 name: name.to_string(),
202 visibility: ChannelVisibility::Members,
203 parent_path: parent_key,
204 channel_order: order,
205 });
206 }
207
208 result
209}
210
211static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
212
213async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
214 db.create_user(
215 email,
216 None,
217 false,
218 NewUserParams {
219 github_login: email[0..email.find('@').unwrap()].to_string(),
220 github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
221 },
222 )
223 .await
224 .unwrap()
225 .user_id
226}
227
228static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
229fn new_test_connection(server: ServerId) -> ConnectionId {
230 ConnectionId {
231 id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
232 owner_id: server.0 as u32,
233 }
234}