1mod billing_subscription_tests;
2mod buffer_tests;
3mod channel_tests;
4mod contributor_tests;
5mod db_tests;
// Postgres-backed tests currently run only on macOS.
7#[cfg(target_os = "macos")]
8mod embedding_tests;
9mod extension_tests;
10mod feature_flag_tests;
11mod message_tests;
12mod processed_stripe_event_tests;
13mod user_tests;
14
15use crate::migrations::run_database_migrations;
16
17use super::*;
18use gpui::BackgroundExecutor;
19use parking_lot::Mutex;
20use sea_orm::ConnectionTrait;
21use sqlx::migrate::MigrateDatabase;
22use std::sync::{
23 Arc,
24 atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
25};
26
/// Database fixture for tests, backed either by in-memory SQLite
/// (`TestDb::sqlite`) or by a freshly created Postgres database
/// (`TestDb::postgres`).
pub struct TestDb {
    /// The database handle. Wrapped in `Option` so that `Drop` can take
    /// ownership of it during teardown.
    pub db: Option<Arc<Database>>,
    /// An optional raw sqlx connection.
    /// NOTE(review): both constructors in this module leave this `None`;
    /// confirm whether anything still sets it.
    pub connection: Option<sqlx::AnyConnection>,
}
31
32impl TestDb {
33 pub fn sqlite(background: BackgroundExecutor) -> Self {
34 let url = "sqlite::memory:";
35 let runtime = tokio::runtime::Builder::new_current_thread()
36 .enable_io()
37 .enable_time()
38 .build()
39 .unwrap();
40
41 let mut db = runtime.block_on(async {
42 let mut options = ConnectOptions::new(url);
43 options.max_connections(5);
44 let mut db = Database::new(options, Executor::Deterministic(background))
45 .await
46 .unwrap();
47 let sql = include_str!(concat!(
48 env!("CARGO_MANIFEST_DIR"),
49 "/migrations.sqlite/20221109000000_test_schema.sql"
50 ));
51 db.pool
52 .execute(sea_orm::Statement::from_string(
53 db.pool.get_database_backend(),
54 sql,
55 ))
56 .await
57 .unwrap();
58 db.initialize_notification_kinds().await.unwrap();
59 db
60 });
61
62 db.runtime = Some(runtime);
63
64 Self {
65 db: Some(Arc::new(db)),
66 connection: None,
67 }
68 }
69
70 pub fn postgres(background: BackgroundExecutor) -> Self {
71 static LOCK: Mutex<()> = Mutex::new(());
72
73 let _guard = LOCK.lock();
74 let mut rng = StdRng::from_entropy();
75 let url = format!(
76 "postgres://postgres@localhost/zed-test-{}",
77 rng.r#gen::<u128>()
78 );
79 let runtime = tokio::runtime::Builder::new_current_thread()
80 .enable_io()
81 .enable_time()
82 .build()
83 .unwrap();
84
85 let mut db = runtime.block_on(async {
86 sqlx::Postgres::create_database(&url)
87 .await
88 .expect("failed to create test db");
89 let mut options = ConnectOptions::new(url);
90 options
91 .max_connections(5)
92 .idle_timeout(Duration::from_secs(0));
93 let mut db = Database::new(options, Executor::Deterministic(background))
94 .await
95 .unwrap();
96 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
97 run_database_migrations(db.options(), migrations_path)
98 .await
99 .unwrap();
100 db.initialize_notification_kinds().await.unwrap();
101 db
102 });
103
104 db.runtime = Some(runtime);
105
106 Self {
107 db: Some(Arc::new(db)),
108 connection: None,
109 }
110 }
111
112 pub fn db(&self) -> &Arc<Database> {
113 self.db.as_ref().unwrap()
114 }
115}
116
/// Expands a shared `async fn(&Arc<Database>)` test body into two
/// `#[gpui::test]` functions:
///
/// * `$sqlite_test_name`, which runs it against `TestDb::sqlite`, and
/// * `$postgres_test_name`, which runs it against `TestDb::postgres` and is
///   compiled only on macOS.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[gpui::test]
        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            // Bound to a local so the fixture outlives the awaited test body.
            let harness = $crate::db::TestDb::sqlite(executor);
            $test_name(harness.db()).await;
        }

        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            // Bound to a local so the fixture outlives the awaited test body.
            let harness = $crate::db::TestDb::postgres(executor);
            $test_name(harness.db()).await;
        }
    };
}
134
135impl Drop for TestDb {
136 fn drop(&mut self) {
137 let db = self.db.take().unwrap();
138 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
139 db.runtime.as_ref().unwrap().block_on(async {
140 use util::ResultExt;
141 let query = "
142 SELECT pg_terminate_backend(pg_stat_activity.pid)
143 FROM pg_stat_activity
144 WHERE
145 pg_stat_activity.datname = current_database() AND
146 pid <> pg_backend_pid();
147 ";
148 db.pool
149 .execute(sea_orm::Statement::from_string(
150 db.pool.get_database_backend(),
151 query,
152 ))
153 .await
154 .log_err();
155 sqlx::Postgres::drop_database(db.options.get_url())
156 .await
157 .log_err();
158 })
159 }
160 }
161}
162
163fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
164 channels
165 .iter()
166 .map(|(id, parent_path, name)| Channel {
167 id: *id,
168 name: name.to_string(),
169 visibility: ChannelVisibility::Members,
170 parent_path: parent_path.to_vec(),
171 })
172 .collect()
173}
174
175static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
176
177async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
178 db.create_user(
179 email,
180 None,
181 false,
182 NewUserParams {
183 github_login: email[0..email.find('@').unwrap()].to_string(),
184 github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
185 },
186 )
187 .await
188 .unwrap()
189 .user_id
190}
191
192static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
193fn new_test_connection(server: ServerId) -> ConnectionId {
194 ConnectionId {
195 id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
196 owner_id: server.0 as u32,
197 }
198}