1mod buffer_tests;
2mod channel_tests;
3mod contributor_tests;
4mod db_tests;
5mod feature_flag_tests;
6mod message_tests;
7
8use super::*;
9use gpui::BackgroundExecutor;
10use parking_lot::Mutex;
11use sea_orm::ConnectionTrait;
12use sqlx::migrate::MigrateDatabase;
13use std::sync::{
14 atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
15 Arc,
16};
17
18pub struct TestDb {
19 pub db: Option<Arc<Database>>,
20 pub connection: Option<sqlx::AnyConnection>,
21}
22
23impl TestDb {
24 pub fn sqlite(background: BackgroundExecutor) -> Self {
25 let url = format!("sqlite::memory:");
26 let runtime = tokio::runtime::Builder::new_current_thread()
27 .enable_io()
28 .enable_time()
29 .build()
30 .unwrap();
31
32 let mut db = runtime.block_on(async {
33 let mut options = ConnectOptions::new(url);
34 options.max_connections(5);
35 let mut db = Database::new(options, Executor::Deterministic(background))
36 .await
37 .unwrap();
38 let sql = include_str!(concat!(
39 env!("CARGO_MANIFEST_DIR"),
40 "/migrations.sqlite/20221109000000_test_schema.sql"
41 ));
42 db.pool
43 .execute(sea_orm::Statement::from_string(
44 db.pool.get_database_backend(),
45 sql,
46 ))
47 .await
48 .unwrap();
49 db.initialize_notification_kinds().await.unwrap();
50 db
51 });
52
53 db.runtime = Some(runtime);
54
55 Self {
56 db: Some(Arc::new(db)),
57 connection: None,
58 }
59 }
60
61 pub fn postgres(background: BackgroundExecutor) -> Self {
62 static LOCK: Mutex<()> = Mutex::new(());
63
64 let _guard = LOCK.lock();
65 let mut rng = StdRng::from_entropy();
66 let url = format!(
67 "postgres://postgres@localhost/zed-test-{}",
68 rng.gen::<u128>()
69 );
70 let runtime = tokio::runtime::Builder::new_current_thread()
71 .enable_io()
72 .enable_time()
73 .build()
74 .unwrap();
75
76 let mut db = runtime.block_on(async {
77 sqlx::Postgres::create_database(&url)
78 .await
79 .expect("failed to create test db");
80 let mut options = ConnectOptions::new(url);
81 options
82 .max_connections(5)
83 .idle_timeout(Duration::from_secs(0));
84 let mut db = Database::new(options, Executor::Deterministic(background))
85 .await
86 .unwrap();
87 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
88 db.migrate(Path::new(migrations_path), false).await.unwrap();
89 db.initialize_notification_kinds().await.unwrap();
90 db
91 });
92
93 db.runtime = Some(runtime);
94
95 Self {
96 db: Some(Arc::new(db)),
97 connection: None,
98 }
99 }
100
101 pub fn db(&self) -> &Arc<Database> {
102 self.db.as_ref().unwrap()
103 }
104}
105
/// Generates two `#[gpui::test]` async tests that run the same test body
/// against both database backends.
///
/// * `$test_name` - an async function that is called with the value returned
///   by `TestDb::db()`.
/// * `$postgres_test_name` - name of the generated test that uses
///   `TestDb::postgres`.
/// * `$sqlite_test_name` - name of the generated test that uses
///   `TestDb::sqlite`.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let postgres = crate::db::TestDb::postgres(executor);
            $test_name(postgres.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let sqlite = crate::db::TestDb::sqlite(executor);
            $test_name(sqlite.db()).await;
        }
    };
}
122
123impl Drop for TestDb {
124 fn drop(&mut self) {
125 let db = self.db.take().unwrap();
126 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
127 db.runtime.as_ref().unwrap().block_on(async {
128 use util::ResultExt;
129 let query = "
130 SELECT pg_terminate_backend(pg_stat_activity.pid)
131 FROM pg_stat_activity
132 WHERE
133 pg_stat_activity.datname = current_database() AND
134 pid <> pg_backend_pid();
135 ";
136 db.pool
137 .execute(sea_orm::Statement::from_string(
138 db.pool.get_database_backend(),
139 query,
140 ))
141 .await
142 .log_err();
143 sqlx::Postgres::drop_database(db.options.get_url())
144 .await
145 .log_err();
146 })
147 }
148 }
149}
150
151fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
152 channels
153 .iter()
154 .map(|(id, parent_path, name)| Channel {
155 id: *id,
156 name: name.to_string(),
157 visibility: ChannelVisibility::Members,
158 parent_path: parent_path.to_vec(),
159 })
160 .collect()
161}
162
163static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
164
165async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
166 db.create_user(
167 email,
168 false,
169 NewUserParams {
170 github_login: email[0..email.find("@").unwrap()].to_string(),
171 github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
172 },
173 )
174 .await
175 .unwrap()
176 .user_id
177}
178
179static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
180fn new_test_connection(server: ServerId) -> ConnectionId {
181 ConnectionId {
182 id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
183 owner_id: server.0 as u32,
184 }
185}