1mod buffer_tests;
2mod channel_tests;
3mod db_tests;
4mod feature_flag_tests;
5mod message_tests;
6
7use super::*;
8use gpui::executor::Background;
9use parking_lot::Mutex;
10use sea_orm::ConnectionTrait;
11use sqlx::migrate::MigrateDatabase;
12use std::sync::{
13 atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
14 Arc,
15};
16
/// Release-channel name for tests. It is not referenced in this file;
/// presumably it is consumed by the sibling test modules via `super::*`.
const TEST_RELEASE_CHANNEL: &str = "test";
18
/// A database handle for tests, built by `TestDb::sqlite` or `TestDb::postgres`.
///
/// Its `Drop` impl takes `db` and, for Postgres, removes the created database.
pub struct TestDb {
    /// The database under test. Both constructors set this to `Some`; `Drop` takes it.
    pub db: Option<Arc<Database>>,
    /// Optional raw sqlx connection; both constructors in this file leave it `None`.
    pub connection: Option<sqlx::AnyConnection>,
}
23
24impl TestDb {
25 pub fn sqlite(background: Arc<Background>) -> Self {
26 let url = format!("sqlite::memory:");
27 let runtime = tokio::runtime::Builder::new_current_thread()
28 .enable_io()
29 .enable_time()
30 .build()
31 .unwrap();
32
33 let mut db = runtime.block_on(async {
34 let mut options = ConnectOptions::new(url);
35 options.max_connections(5);
36 let mut db = Database::new(options, Executor::Deterministic(background))
37 .await
38 .unwrap();
39 let sql = include_str!(concat!(
40 env!("CARGO_MANIFEST_DIR"),
41 "/migrations.sqlite/20221109000000_test_schema.sql"
42 ));
43 db.pool
44 .execute(sea_orm::Statement::from_string(
45 db.pool.get_database_backend(),
46 sql,
47 ))
48 .await
49 .unwrap();
50 db.initialize_notification_kinds().await.unwrap();
51 db
52 });
53
54 db.runtime = Some(runtime);
55
56 Self {
57 db: Some(Arc::new(db)),
58 connection: None,
59 }
60 }
61
62 pub fn postgres(background: Arc<Background>) -> Self {
63 static LOCK: Mutex<()> = Mutex::new(());
64
65 let _guard = LOCK.lock();
66 let mut rng = StdRng::from_entropy();
67 let url = format!(
68 "postgres://postgres@localhost/zed-test-{}",
69 rng.gen::<u128>()
70 );
71 let runtime = tokio::runtime::Builder::new_current_thread()
72 .enable_io()
73 .enable_time()
74 .build()
75 .unwrap();
76
77 let mut db = runtime.block_on(async {
78 sqlx::Postgres::create_database(&url)
79 .await
80 .expect("failed to create test db");
81 let mut options = ConnectOptions::new(url);
82 options
83 .max_connections(5)
84 .idle_timeout(Duration::from_secs(0));
85 let mut db = Database::new(options, Executor::Deterministic(background))
86 .await
87 .unwrap();
88 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
89 db.migrate(Path::new(migrations_path), false).await.unwrap();
90 db.initialize_notification_kinds().await.unwrap();
91 db
92 });
93
94 db.runtime = Some(runtime);
95
96 Self {
97 db: Some(Arc::new(db)),
98 connection: None,
99 }
100 }
101
102 pub fn db(&self) -> &Arc<Database> {
103 self.db.as_ref().unwrap()
104 }
105}
106
/// Defines two `#[gpui::test]` functions that run the same async test against
/// both backends.
///
/// * `$postgres_test_name` runs `$test_name` against `TestDb::postgres`.
/// * `$sqlite_test_name` runs `$test_name` against `TestDb::sqlite`.
///
/// `$test_name` is called with the `&Arc<Database>` returned by
/// `TestDb::db` and awaited. Paths use `$crate` so the macro resolves against
/// the defining crate wherever it is expanded.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[gpui::test]
        async fn $postgres_test_name() {
            let test_db = $crate::db::TestDb::postgres(
                gpui::executor::Deterministic::new(0).build_background(),
            );
            $test_name(test_db.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_test_name() {
            let test_db = $crate::db::TestDb::sqlite(
                gpui::executor::Deterministic::new(0).build_background(),
            );
            $test_name(test_db.db()).await;
        }
    };
}
126
127impl Drop for TestDb {
128 fn drop(&mut self) {
129 let db = self.db.take().unwrap();
130 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
131 db.runtime.as_ref().unwrap().block_on(async {
132 use util::ResultExt;
133 let query = "
134 SELECT pg_terminate_backend(pg_stat_activity.pid)
135 FROM pg_stat_activity
136 WHERE
137 pg_stat_activity.datname = current_database() AND
138 pid <> pg_backend_pid();
139 ";
140 db.pool
141 .execute(sea_orm::Statement::from_string(
142 db.pool.get_database_backend(),
143 query,
144 ))
145 .await
146 .log_err();
147 sqlx::Postgres::drop_database(db.options.get_url())
148 .await
149 .log_err();
150 })
151 }
152 }
153}
154
155fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str, ChannelRole)]) -> Vec<Channel> {
156 channels
157 .iter()
158 .map(|(id, parent_path, name, role)| Channel {
159 id: *id,
160 name: name.to_string(),
161 visibility: ChannelVisibility::Members,
162 role: *role,
163 parent_path: parent_path.to_vec(),
164 })
165 .collect()
166}
167
168static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
169
170async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
171 db.create_user(
172 email,
173 false,
174 NewUserParams {
175 github_login: email[0..email.find("@").unwrap()].to_string(),
176 github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
177 },
178 )
179 .await
180 .unwrap()
181 .user_id
182}
183
184static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
185fn new_test_connection(server: ServerId) -> ConnectionId {
186 ConnectionId {
187 id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
188 owner_id: server.0 as u32,
189 }
190}