1mod buffer_tests;
2mod channel_tests;
3mod contributor_tests;
4mod db_tests;
5mod embedding_tests;
6mod extension_tests;
7mod feature_flag_tests;
8mod message_tests;
9
10use super::*;
11use gpui::BackgroundExecutor;
12use parking_lot::Mutex;
13use sea_orm::ConnectionTrait;
14use sqlx::migrate::MigrateDatabase;
15use std::sync::{
16 atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
17 Arc,
18};
19
20pub struct TestDb {
21 pub db: Option<Arc<Database>>,
22 pub connection: Option<sqlx::AnyConnection>,
23}
24
25impl TestDb {
26 pub fn sqlite(background: BackgroundExecutor) -> Self {
27 let url = "sqlite::memory:";
28 let runtime = tokio::runtime::Builder::new_current_thread()
29 .enable_io()
30 .enable_time()
31 .build()
32 .unwrap();
33
34 let mut db = runtime.block_on(async {
35 let mut options = ConnectOptions::new(url);
36 options.max_connections(5);
37 let mut db = Database::new(options, Executor::Deterministic(background))
38 .await
39 .unwrap();
40 let sql = include_str!(concat!(
41 env!("CARGO_MANIFEST_DIR"),
42 "/migrations.sqlite/20221109000000_test_schema.sql"
43 ));
44 db.pool
45 .execute(sea_orm::Statement::from_string(
46 db.pool.get_database_backend(),
47 sql,
48 ))
49 .await
50 .unwrap();
51 db.initialize_notification_kinds().await.unwrap();
52 db
53 });
54
55 db.runtime = Some(runtime);
56
57 Self {
58 db: Some(Arc::new(db)),
59 connection: None,
60 }
61 }
62
63 pub fn postgres(background: BackgroundExecutor) -> Self {
64 static LOCK: Mutex<()> = Mutex::new(());
65
66 let _guard = LOCK.lock();
67 let mut rng = StdRng::from_entropy();
68 let url = format!(
69 "postgres://postgres@localhost/zed-test-{}",
70 rng.gen::<u128>()
71 );
72 let runtime = tokio::runtime::Builder::new_current_thread()
73 .enable_io()
74 .enable_time()
75 .build()
76 .unwrap();
77
78 let mut db = runtime.block_on(async {
79 sqlx::Postgres::create_database(&url)
80 .await
81 .expect("failed to create test db");
82 let mut options = ConnectOptions::new(url);
83 options
84 .max_connections(5)
85 .idle_timeout(Duration::from_secs(0));
86 let mut db = Database::new(options, Executor::Deterministic(background))
87 .await
88 .unwrap();
89 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
90 db.migrate(Path::new(migrations_path), false).await.unwrap();
91 db.initialize_notification_kinds().await.unwrap();
92 db
93 });
94
95 db.runtime = Some(runtime);
96
97 Self {
98 db: Some(Arc::new(db)),
99 connection: None,
100 }
101 }
102
103 pub fn db(&self) -> &Arc<Database> {
104 self.db.as_ref().unwrap()
105 }
106}
107
/// Generates two `#[gpui::test]` async tests that run the same test body
/// against both database backends.
///
/// * `$test_name` - a function that is called with the `&Arc<Database>`
///   returned by `TestDb::db()` and whose result is awaited.
/// * `$postgres_test_name` - name of the generated test that uses
///   `TestDb::postgres`.
/// * `$sqlite_test_name` - name of the generated test that uses
///   `TestDb::sqlite`.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            // Keep the `TestDb` in a local so it outlives the awaited body.
            let database = $crate::db::TestDb::postgres(executor);
            $test_name(database.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            // Keep the `TestDb` in a local so it outlives the awaited body.
            let database = $crate::db::TestDb::sqlite(executor);
            $test_name(database.db()).await;
        }
    };
}
124
125impl Drop for TestDb {
126 fn drop(&mut self) {
127 let db = self.db.take().unwrap();
128 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
129 db.runtime.as_ref().unwrap().block_on(async {
130 use util::ResultExt;
131 let query = "
132 SELECT pg_terminate_backend(pg_stat_activity.pid)
133 FROM pg_stat_activity
134 WHERE
135 pg_stat_activity.datname = current_database() AND
136 pid <> pg_backend_pid();
137 ";
138 db.pool
139 .execute(sea_orm::Statement::from_string(
140 db.pool.get_database_backend(),
141 query,
142 ))
143 .await
144 .log_err();
145 sqlx::Postgres::drop_database(db.options.get_url())
146 .await
147 .log_err();
148 })
149 }
150 }
151}
152
153fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
154 channels
155 .iter()
156 .map(|(id, parent_path, name)| Channel {
157 id: *id,
158 name: name.to_string(),
159 visibility: ChannelVisibility::Members,
160 parent_path: parent_path.to_vec(),
161 })
162 .collect()
163}
164
165static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
166
167async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
168 db.create_user(
169 email,
170 false,
171 NewUserParams {
172 github_login: email[0..email.find('@').unwrap()].to_string(),
173 github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
174 },
175 )
176 .await
177 .unwrap()
178 .user_id
179}
180
181static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
182fn new_test_connection(server: ServerId) -> ConnectionId {
183 ConnectionId {
184 id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
185 owner_id: server.0 as u32,
186 }
187}