1mod billing_subscription_tests;
2mod buffer_tests;
3mod channel_tests;
4mod contributor_tests;
5mod db_tests;
// Postgres-backed tests currently run only on macOS, so this module is gated to that target.
7#[cfg(target_os = "macos")]
8mod embedding_tests;
9mod extension_tests;
10mod feature_flag_tests;
11mod message_tests;
12mod processed_stripe_event_tests;
13
14use crate::migrations::run_database_migrations;
15
16use super::*;
17use gpui::BackgroundExecutor;
18use parking_lot::Mutex;
19use sea_orm::ConnectionTrait;
20use sqlx::migrate::MigrateDatabase;
21use std::sync::{
22 atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
23 Arc,
24};
25
/// A throwaway database for tests, created by [`TestDb::sqlite`] or
/// [`TestDb::postgres`]. When the backend is Postgres, the database is
/// dropped again in this type's `Drop` implementation.
pub struct TestDb {
    /// The database under test. Both constructors set this to `Some`; `Drop`
    /// takes it back out.
    pub db: Option<Arc<Database>>,
    /// Raw sqlx connection slot. Neither constructor in this file sets it
    /// (both leave it `None`).
    pub connection: Option<sqlx::AnyConnection>,
}
30
31impl TestDb {
32 pub fn sqlite(background: BackgroundExecutor) -> Self {
33 let url = "sqlite::memory:";
34 let runtime = tokio::runtime::Builder::new_current_thread()
35 .enable_io()
36 .enable_time()
37 .build()
38 .unwrap();
39
40 let mut db = runtime.block_on(async {
41 let mut options = ConnectOptions::new(url);
42 options.max_connections(5);
43 let mut db = Database::new(options, Executor::Deterministic(background))
44 .await
45 .unwrap();
46 let sql = include_str!(concat!(
47 env!("CARGO_MANIFEST_DIR"),
48 "/migrations.sqlite/20221109000000_test_schema.sql"
49 ));
50 db.pool
51 .execute(sea_orm::Statement::from_string(
52 db.pool.get_database_backend(),
53 sql,
54 ))
55 .await
56 .unwrap();
57 db.initialize_notification_kinds().await.unwrap();
58 db
59 });
60
61 db.runtime = Some(runtime);
62
63 Self {
64 db: Some(Arc::new(db)),
65 connection: None,
66 }
67 }
68
69 pub fn postgres(background: BackgroundExecutor) -> Self {
70 static LOCK: Mutex<()> = Mutex::new(());
71
72 let _guard = LOCK.lock();
73 let mut rng = StdRng::from_entropy();
74 let url = format!(
75 "postgres://postgres@localhost/zed-test-{}",
76 rng.gen::<u128>()
77 );
78 let runtime = tokio::runtime::Builder::new_current_thread()
79 .enable_io()
80 .enable_time()
81 .build()
82 .unwrap();
83
84 let mut db = runtime.block_on(async {
85 sqlx::Postgres::create_database(&url)
86 .await
87 .expect("failed to create test db");
88 let mut options = ConnectOptions::new(url);
89 options
90 .max_connections(5)
91 .idle_timeout(Duration::from_secs(0));
92 let mut db = Database::new(options, Executor::Deterministic(background))
93 .await
94 .unwrap();
95 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
96 run_database_migrations(db.options(), migrations_path)
97 .await
98 .unwrap();
99 db.initialize_notification_kinds().await.unwrap();
100 db
101 });
102
103 db.runtime = Some(runtime);
104
105 Self {
106 db: Some(Arc::new(db)),
107 connection: None,
108 }
109 }
110
111 pub fn db(&self) -> &Arc<Database> {
112 self.db.as_ref().unwrap()
113 }
114}
115
/// Generates two `#[gpui::test]` functions that run the same async test body
/// against both database backends:
///
/// - the second identifier names a test that runs the body against a fresh
///   Postgres database (compiled only on macOS);
/// - the third identifier names a test that runs the body against an
///   in-memory SQLite database.
///
/// The first identifier must name a function that takes `&Arc<Database>` and
/// returns a future.
#[macro_export]
macro_rules! test_both_dbs {
    ($body:ident, $pg_fn:ident, $lite_fn:ident) => {
        #[gpui::test]
        async fn $lite_fn(cx: &mut gpui::TestAppContext) {
            let harness = $crate::db::TestDb::sqlite(cx.executor().clone());
            $body(harness.db()).await;
        }

        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $pg_fn(cx: &mut gpui::TestAppContext) {
            let harness = $crate::db::TestDb::postgres(cx.executor().clone());
            $body(harness.db()).await;
        }
    };
}
133
134impl Drop for TestDb {
135 fn drop(&mut self) {
136 let db = self.db.take().unwrap();
137 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
138 db.runtime.as_ref().unwrap().block_on(async {
139 use util::ResultExt;
140 let query = "
141 SELECT pg_terminate_backend(pg_stat_activity.pid)
142 FROM pg_stat_activity
143 WHERE
144 pg_stat_activity.datname = current_database() AND
145 pid <> pg_backend_pid();
146 ";
147 db.pool
148 .execute(sea_orm::Statement::from_string(
149 db.pool.get_database_backend(),
150 query,
151 ))
152 .await
153 .log_err();
154 sqlx::Postgres::drop_database(db.options.get_url())
155 .await
156 .log_err();
157 })
158 }
159 }
160}
161
162fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
163 channels
164 .iter()
165 .map(|(id, parent_path, name)| Channel {
166 id: *id,
167 name: name.to_string(),
168 visibility: ChannelVisibility::Members,
169 parent_path: parent_path.to_vec(),
170 })
171 .collect()
172}
173
174static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
175
176async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
177 db.create_user(
178 email,
179 false,
180 NewUserParams {
181 github_login: email[0..email.find('@').unwrap()].to_string(),
182 github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
183 },
184 )
185 .await
186 .unwrap()
187 .user_id
188}
189
190static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
191fn new_test_connection(server: ServerId) -> ConnectionId {
192 ConnectionId {
193 id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
194 owner_id: server.0 as u32,
195 }
196}