1mod billing_subscription_tests;
2mod buffer_tests;
3mod channel_tests;
4mod contributor_tests;
5mod db_tests;
6// we only run postgres tests on macos right now
7#[cfg(target_os = "macos")]
8mod embedding_tests;
9mod extension_tests;
10mod feature_flag_tests;
11mod message_tests;
12mod processed_stripe_event_tests;
13
14use super::*;
15use gpui::BackgroundExecutor;
16use parking_lot::Mutex;
17use sea_orm::ConnectionTrait;
18use sqlx::migrate::MigrateDatabase;
19use std::sync::{
20 atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
21 Arc,
22};
23
/// A throwaway database for tests, created by `TestDb::sqlite` or `TestDb::postgres`.
///
/// For the Postgres backend, the `Drop` impl attempts to drop the database again.
pub struct TestDb {
    /// The database under test. Both constructors set this to `Some`; `Drop` takes it.
    pub db: Option<Arc<Database>>,
    /// Neither constructor in this module populates this; it is always `None` here.
    pub connection: Option<sqlx::AnyConnection>,
}
28
29impl TestDb {
30 pub fn sqlite(background: BackgroundExecutor) -> Self {
31 let url = "sqlite::memory:";
32 let runtime = tokio::runtime::Builder::new_current_thread()
33 .enable_io()
34 .enable_time()
35 .build()
36 .unwrap();
37
38 let mut db = runtime.block_on(async {
39 let mut options = ConnectOptions::new(url);
40 options.max_connections(5);
41 let mut db = Database::new(options, Executor::Deterministic(background))
42 .await
43 .unwrap();
44 let sql = include_str!(concat!(
45 env!("CARGO_MANIFEST_DIR"),
46 "/migrations.sqlite/20221109000000_test_schema.sql"
47 ));
48 db.pool
49 .execute(sea_orm::Statement::from_string(
50 db.pool.get_database_backend(),
51 sql,
52 ))
53 .await
54 .unwrap();
55 db.initialize_notification_kinds().await.unwrap();
56 db
57 });
58
59 db.runtime = Some(runtime);
60
61 Self {
62 db: Some(Arc::new(db)),
63 connection: None,
64 }
65 }
66
67 pub fn postgres(background: BackgroundExecutor) -> Self {
68 static LOCK: Mutex<()> = Mutex::new(());
69
70 let _guard = LOCK.lock();
71 let mut rng = StdRng::from_entropy();
72 let url = format!(
73 "postgres://postgres@localhost/zed-test-{}",
74 rng.gen::<u128>()
75 );
76 let runtime = tokio::runtime::Builder::new_current_thread()
77 .enable_io()
78 .enable_time()
79 .build()
80 .unwrap();
81
82 let mut db = runtime.block_on(async {
83 sqlx::Postgres::create_database(&url)
84 .await
85 .expect("failed to create test db");
86 let mut options = ConnectOptions::new(url);
87 options
88 .max_connections(5)
89 .idle_timeout(Duration::from_secs(0));
90 let mut db = Database::new(options, Executor::Deterministic(background))
91 .await
92 .unwrap();
93 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
94 db.migrate(Path::new(migrations_path), false).await.unwrap();
95 db.initialize_notification_kinds().await.unwrap();
96 db
97 });
98
99 db.runtime = Some(runtime);
100
101 Self {
102 db: Some(Arc::new(db)),
103 connection: None,
104 }
105 }
106
107 pub fn db(&self) -> &Arc<Database> {
108 self.db.as_ref().unwrap()
109 }
110}
111
/// Defines two async `#[gpui::test]` functions that run one test body against
/// both database backends.
///
/// `$test_name` is called with the `&Arc<Database>` returned by `TestDb::db`,
/// and its result is awaited.
///
/// - `$postgres_test_name` is only compiled on macOS and uses a fresh
///   `TestDb::postgres`.
/// - `$sqlite_test_name` uses a fresh `TestDb::sqlite`.
///
/// Each `TestDb` lives until the end of its generated test function.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let postgres_db = $crate::db::TestDb::postgres(executor);
            $test_name(postgres_db.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let sqlite_db = $crate::db::TestDb::sqlite(executor);
            $test_name(sqlite_db.db()).await;
        }
    };
}
129
130impl Drop for TestDb {
131 fn drop(&mut self) {
132 let db = self.db.take().unwrap();
133 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
134 db.runtime.as_ref().unwrap().block_on(async {
135 use util::ResultExt;
136 let query = "
137 SELECT pg_terminate_backend(pg_stat_activity.pid)
138 FROM pg_stat_activity
139 WHERE
140 pg_stat_activity.datname = current_database() AND
141 pid <> pg_backend_pid();
142 ";
143 db.pool
144 .execute(sea_orm::Statement::from_string(
145 db.pool.get_database_backend(),
146 query,
147 ))
148 .await
149 .log_err();
150 sqlx::Postgres::drop_database(db.options.get_url())
151 .await
152 .log_err();
153 })
154 }
155 }
156}
157
158fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
159 channels
160 .iter()
161 .map(|(id, parent_path, name)| Channel {
162 id: *id,
163 name: name.to_string(),
164 visibility: ChannelVisibility::Members,
165 parent_path: parent_path.to_vec(),
166 })
167 .collect()
168}
169
170static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
171
172async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
173 db.create_user(
174 email,
175 false,
176 NewUserParams {
177 github_login: email[0..email.find('@').unwrap()].to_string(),
178 github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
179 },
180 )
181 .await
182 .unwrap()
183 .user_id
184}
185
186static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
187fn new_test_connection(server: ServerId) -> ConnectionId {
188 ConnectionId {
189 id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
190 owner_id: server.0 as u32,
191 }
192}