1mod billing_subscription_tests;
2mod buffer_tests;
3mod channel_tests;
4mod contributor_tests;
5mod db_tests;
6// we only run postgres tests on macos right now
7#[cfg(target_os = "macos")]
8mod embedding_tests;
9mod extension_tests;
10mod feature_flag_tests;
11mod message_tests;
12
13use super::*;
14use gpui::BackgroundExecutor;
15use parking_lot::Mutex;
16use sea_orm::ConnectionTrait;
17use sqlx::migrate::MigrateDatabase;
18use std::sync::{
19 atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
20 Arc,
21};
22
/// A database handle for tests, created with `TestDb::sqlite` or
/// `TestDb::postgres`.
///
/// For Postgres, dropping a `TestDb` makes a best-effort attempt to terminate
/// other connections to the per-test database and then drop it.
pub struct TestDb {
    /// The database under test. Both constructors set this to `Some`; the
    /// `Drop` impl takes it out.
    pub db: Option<Arc<Database>>,
    /// NOTE(review): both constructors set this to `None` and nothing visible
    /// here reads it — possibly unused; confirm before relying on it.
    pub connection: Option<sqlx::AnyConnection>,
}
27
28impl TestDb {
29 pub fn sqlite(background: BackgroundExecutor) -> Self {
30 let url = "sqlite::memory:";
31 let runtime = tokio::runtime::Builder::new_current_thread()
32 .enable_io()
33 .enable_time()
34 .build()
35 .unwrap();
36
37 let mut db = runtime.block_on(async {
38 let mut options = ConnectOptions::new(url);
39 options.max_connections(5);
40 let mut db = Database::new(options, Executor::Deterministic(background))
41 .await
42 .unwrap();
43 let sql = include_str!(concat!(
44 env!("CARGO_MANIFEST_DIR"),
45 "/migrations.sqlite/20221109000000_test_schema.sql"
46 ));
47 db.pool
48 .execute(sea_orm::Statement::from_string(
49 db.pool.get_database_backend(),
50 sql,
51 ))
52 .await
53 .unwrap();
54 db.initialize_notification_kinds().await.unwrap();
55 db
56 });
57
58 db.runtime = Some(runtime);
59
60 Self {
61 db: Some(Arc::new(db)),
62 connection: None,
63 }
64 }
65
66 pub fn postgres(background: BackgroundExecutor) -> Self {
67 static LOCK: Mutex<()> = Mutex::new(());
68
69 let _guard = LOCK.lock();
70 let mut rng = StdRng::from_entropy();
71 let url = format!(
72 "postgres://postgres@localhost/zed-test-{}",
73 rng.gen::<u128>()
74 );
75 let runtime = tokio::runtime::Builder::new_current_thread()
76 .enable_io()
77 .enable_time()
78 .build()
79 .unwrap();
80
81 let mut db = runtime.block_on(async {
82 sqlx::Postgres::create_database(&url)
83 .await
84 .expect("failed to create test db");
85 let mut options = ConnectOptions::new(url);
86 options
87 .max_connections(5)
88 .idle_timeout(Duration::from_secs(0));
89 let mut db = Database::new(options, Executor::Deterministic(background))
90 .await
91 .unwrap();
92 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
93 db.migrate(Path::new(migrations_path), false).await.unwrap();
94 db.initialize_notification_kinds().await.unwrap();
95 db
96 });
97
98 db.runtime = Some(runtime);
99
100 Self {
101 db: Some(Arc::new(db)),
102 connection: None,
103 }
104 }
105
106 pub fn db(&self) -> &Arc<Database> {
107 self.db.as_ref().unwrap()
108 }
109}
110
/// Generates two `#[gpui::test]` functions that run the same async test body
/// against each database backend.
///
/// - `$sqlite_name` runs `$run` against `TestDb::sqlite`.
/// - `$pg_name` runs `$run` against `TestDb::postgres`, and is only compiled on
///   macOS.
///
/// `$run` is called with a `&Arc<Database>` and its result is awaited.
#[macro_export]
macro_rules! test_both_dbs {
    ($run:ident, $pg_name:ident, $sqlite_name:ident) => {
        #[gpui::test]
        async fn $sqlite_name(cx: &mut gpui::TestAppContext) {
            let test_db = $crate::db::TestDb::sqlite(cx.executor().clone());
            let db = test_db.db();
            $run(db).await;
        }

        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $pg_name(cx: &mut gpui::TestAppContext) {
            let test_db = $crate::db::TestDb::postgres(cx.executor().clone());
            let db = test_db.db();
            $run(db).await;
        }
    };
}
128
129impl Drop for TestDb {
130 fn drop(&mut self) {
131 let db = self.db.take().unwrap();
132 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
133 db.runtime.as_ref().unwrap().block_on(async {
134 use util::ResultExt;
135 let query = "
136 SELECT pg_terminate_backend(pg_stat_activity.pid)
137 FROM pg_stat_activity
138 WHERE
139 pg_stat_activity.datname = current_database() AND
140 pid <> pg_backend_pid();
141 ";
142 db.pool
143 .execute(sea_orm::Statement::from_string(
144 db.pool.get_database_backend(),
145 query,
146 ))
147 .await
148 .log_err();
149 sqlx::Postgres::drop_database(db.options.get_url())
150 .await
151 .log_err();
152 })
153 }
154 }
155}
156
157fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
158 channels
159 .iter()
160 .map(|(id, parent_path, name)| Channel {
161 id: *id,
162 name: name.to_string(),
163 visibility: ChannelVisibility::Members,
164 parent_path: parent_path.to_vec(),
165 })
166 .collect()
167}
168
169static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
170
171async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
172 db.create_user(
173 email,
174 false,
175 NewUserParams {
176 github_login: email[0..email.find('@').unwrap()].to_string(),
177 github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
178 },
179 )
180 .await
181 .unwrap()
182 .user_id
183}
184
185static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
186fn new_test_connection(server: ServerId) -> ConnectionId {
187 ConnectionId {
188 id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
189 owner_id: server.0 as u32,
190 }
191}