1mod buffer_tests;
2mod channel_tests;
3mod contributor_tests;
4mod db_tests;
5// we only run postgres tests on macos right now
6#[cfg(target_os = "macos")]
7mod embedding_tests;
8mod extension_tests;
9mod feature_flag_tests;
10mod message_tests;
11
12use super::*;
13use gpui::BackgroundExecutor;
14use parking_lot::Mutex;
15use sea_orm::ConnectionTrait;
16use sqlx::migrate::MigrateDatabase;
17use std::sync::{
18 atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
19 Arc,
20};
21
/// A database created for a single test, together with what is needed to tear it down.
///
/// Construct with [`TestDb::sqlite`] or [`TestDb::postgres`]. When dropped, a Postgres
/// test database is dropped from the server (errors are only logged).
pub struct TestDb {
    /// The database under test. Set to `Some` by both constructors; only `Drop` takes it.
    pub db: Option<Arc<Database>>,
    /// An optional raw sqlx connection. Neither constructor in this file sets it.
    pub connection: Option<sqlx::AnyConnection>,
}
26
27impl TestDb {
28 pub fn sqlite(background: BackgroundExecutor) -> Self {
29 let url = "sqlite::memory:";
30 let runtime = tokio::runtime::Builder::new_current_thread()
31 .enable_io()
32 .enable_time()
33 .build()
34 .unwrap();
35
36 let mut db = runtime.block_on(async {
37 let mut options = ConnectOptions::new(url);
38 options.max_connections(5);
39 let mut db = Database::new(options, Executor::Deterministic(background))
40 .await
41 .unwrap();
42 let sql = include_str!(concat!(
43 env!("CARGO_MANIFEST_DIR"),
44 "/migrations.sqlite/20221109000000_test_schema.sql"
45 ));
46 db.pool
47 .execute(sea_orm::Statement::from_string(
48 db.pool.get_database_backend(),
49 sql,
50 ))
51 .await
52 .unwrap();
53 db.initialize_notification_kinds().await.unwrap();
54 db
55 });
56
57 db.runtime = Some(runtime);
58
59 Self {
60 db: Some(Arc::new(db)),
61 connection: None,
62 }
63 }
64
65 pub fn postgres(background: BackgroundExecutor) -> Self {
66 static LOCK: Mutex<()> = Mutex::new(());
67
68 let _guard = LOCK.lock();
69 let mut rng = StdRng::from_entropy();
70 let url = format!(
71 "postgres://postgres@localhost/zed-test-{}",
72 rng.gen::<u128>()
73 );
74 let runtime = tokio::runtime::Builder::new_current_thread()
75 .enable_io()
76 .enable_time()
77 .build()
78 .unwrap();
79
80 let mut db = runtime.block_on(async {
81 sqlx::Postgres::create_database(&url)
82 .await
83 .expect("failed to create test db");
84 let mut options = ConnectOptions::new(url);
85 options
86 .max_connections(5)
87 .idle_timeout(Duration::from_secs(0));
88 let mut db = Database::new(options, Executor::Deterministic(background))
89 .await
90 .unwrap();
91 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
92 db.migrate(Path::new(migrations_path), false).await.unwrap();
93 db.initialize_notification_kinds().await.unwrap();
94 db
95 });
96
97 db.runtime = Some(runtime);
98
99 Self {
100 db: Some(Arc::new(db)),
101 connection: None,
102 }
103 }
104
105 pub fn db(&self) -> &Arc<Database> {
106 self.db.as_ref().unwrap()
107 }
108}
109
/// Expands to two `#[gpui::test]` functions that each call the async test function
/// `$test_fn` with a `&Arc<Database>`.
///
/// - `$pg_name` runs against a freshly created Postgres database and is only compiled
///   on macOS.
/// - `$sqlite_name` runs against an in-memory SQLite database.
///
/// The `TestDb` fixture is bound for the whole test body, so it is torn down only after
/// `$test_fn` completes.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_fn:ident, $pg_name:ident, $sqlite_name:ident) => {
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $pg_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let fixture = $crate::db::TestDb::postgres(executor);
            $test_fn(fixture.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let fixture = $crate::db::TestDb::sqlite(executor);
            $test_fn(fixture.db()).await;
        }
    };
}
127
128impl Drop for TestDb {
129 fn drop(&mut self) {
130 let db = self.db.take().unwrap();
131 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
132 db.runtime.as_ref().unwrap().block_on(async {
133 use util::ResultExt;
134 let query = "
135 SELECT pg_terminate_backend(pg_stat_activity.pid)
136 FROM pg_stat_activity
137 WHERE
138 pg_stat_activity.datname = current_database() AND
139 pid <> pg_backend_pid();
140 ";
141 db.pool
142 .execute(sea_orm::Statement::from_string(
143 db.pool.get_database_backend(),
144 query,
145 ))
146 .await
147 .log_err();
148 sqlx::Postgres::drop_database(db.options.get_url())
149 .await
150 .log_err();
151 })
152 }
153 }
154}
155
156fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
157 channels
158 .iter()
159 .map(|(id, parent_path, name)| Channel {
160 id: *id,
161 name: name.to_string(),
162 visibility: ChannelVisibility::Members,
163 parent_path: parent_path.to_vec(),
164 })
165 .collect()
166}
167
168static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
169
170async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
171 db.create_user(
172 email,
173 false,
174 NewUserParams {
175 github_login: email[0..email.find('@').unwrap()].to_string(),
176 github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
177 },
178 )
179 .await
180 .unwrap()
181 .user_id
182}
183
184static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
185fn new_test_connection(server: ServerId) -> ConnectionId {
186 ConnectionId {
187 id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
188 owner_id: server.0 as u32,
189 }
190}