1pub mod kvp;
2pub mod query;
3
4// Re-export
5pub use anyhow;
6use anyhow::Context;
7pub use indoc::indoc;
8pub use lazy_static;
9use parking_lot::{Mutex, RwLock};
10pub use smol;
11pub use sqlez;
12pub use sqlez_macros;
13pub use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
14pub use util::paths::DB_DIR;
15
16use sqlez::domain::Migrator;
17use sqlez::thread_safe_connection::ThreadSafeConnection;
18use sqlez_macros::sql;
19use std::fs::create_dir_all;
20use std::path::{Path, PathBuf};
21use std::sync::atomic::{AtomicBool, Ordering};
22use std::time::{SystemTime, UNIX_EPOCH};
23use util::{async_iife, ResultExt};
24use util::channel::ReleaseChannel;
25
26const CONNECTION_INITIALIZE_QUERY: &'static str = sql!(
27 PRAGMA synchronous=NORMAL;
28 PRAGMA busy_timeout=1;
29 PRAGMA foreign_keys=TRUE;
30 PRAGMA case_sensitive_like=TRUE;
31);
32
33const DB_INITIALIZE_QUERY: &'static str = sql!(
34 PRAGMA journal_mode=WAL;
35);
36
37const FALLBACK_DB_NAME: &'static str = "FALLBACK_MEMORY_DB";
38
39const DB_FILE_NAME: &'static str = "db.sqlite";
40
41lazy_static::lazy_static! {
42 static ref DB_FILE_OPERATIONS: Mutex<()> = Mutex::new(());
43 // static ref DB_WIPED: RwLock<bool> = RwLock::new(false);
44 pub static ref BACKUP_DB_PATH: RwLock<Option<PathBuf>> = RwLock::new(None);
45 pub static ref ALL_FILE_DB_FAILED: AtomicBool = AtomicBool::new(false);
46}
47
48/// Open or create a database at the given directory path.
49/// This will retry a couple times if there are failures. If opening fails once, the db directory
50/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
51/// In either case, static variables are set so that the user can be notified.
52pub async fn open_db<M: Migrator + 'static>(db_dir: &Path, release_channel: &ReleaseChannel) -> ThreadSafeConnection<M> {
53 let release_channel_name = release_channel.dev_name();
54 let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel_name)));
55
56 // // If WIPE_DB, delete 0-{channel}
57 // if release_channel == &ReleaseChannel::Dev
58 // && wipe_db
59 // && !*DB_WIPED.read()
60 // {
61 // let mut db_wiped = DB_WIPED.write();
62 // if !*db_wiped {
63 // remove_dir_all(&main_db_dir).ok();
64 // *db_wiped = true;
65 // }
66 // }
67
68 let connection = async_iife!({
69 // Note: This still has a race condition where 1 set of migrations succeeds
70 // (e.g. (Workspace, Editor)) and another fails (e.g. (Workspace, Terminal))
71 // This will cause the first connection to have the database taken out
72 // from under it. This *should* be fine though. The second dabatase failure will
73 // cause errors in the log and so should be observed by developers while writing
74 // soon-to-be good migrations. If user databases are corrupted, we toss them out
75 // and try again from a blank. As long as running all migrations from start to end
76 // on a blank database is ok, this race condition will never be triggered.
77 //
78 // Basically: Don't ever push invalid migrations to stable or everyone will have
79 // a bad time.
80
81 // If no db folder, create one at 0-{channel}
82 create_dir_all(&main_db_dir).context("Could not create db directory")?;
83 let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
84
85 // Optimistically open databases in parallel
86 if !DB_FILE_OPERATIONS.is_locked() {
87 // Try building a connection
88 if let Some(connection) = open_main_db(&db_path).await {
89 return Ok(connection)
90 };
91 }
92
93 // Take a lock in the failure case so that we move the db once per process instead
94 // of potentially multiple times from different threads. This shouldn't happen in the
95 // normal path
96 let _lock = DB_FILE_OPERATIONS.lock();
97 if let Some(connection) = open_main_db(&db_path).await {
98 return Ok(connection)
99 };
100
101 let backup_timestamp = SystemTime::now()
102 .duration_since(UNIX_EPOCH)
103 .expect("System clock is set before the unix timestamp, Zed does not support this region of spacetime")
104 .as_millis();
105
106 // If failed, move 0-{channel} to {current unix timestamp}-{channel}
107 let backup_db_dir = db_dir.join(Path::new(&format!(
108 "{}-{}",
109 backup_timestamp,
110 release_channel_name,
111 )));
112
113 std::fs::rename(&main_db_dir, &backup_db_dir)
114 .context("Failed clean up corrupted database, panicking.")?;
115
116 // Set a static ref with the failed timestamp and error so we can notify the user
117 {
118 let mut guard = BACKUP_DB_PATH.write();
119 *guard = Some(backup_db_dir);
120 }
121
122 // Create a new 0-{channel}
123 create_dir_all(&main_db_dir).context("Should be able to create the database directory")?;
124 let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
125
126 // Try again
127 open_main_db(&db_path).await.context("Could not newly created db")
128 }).await.log_err();
129
130 if let Some(connection) = connection {
131 return connection;
132 }
133
134 // Set another static ref so that we can escalate the notification
135 ALL_FILE_DB_FAILED.store(true, Ordering::Release);
136
137 // If still failed, create an in memory db with a known name
138 open_fallback_db().await
139}
140
141async fn open_main_db<M: Migrator>(db_path: &PathBuf) -> Option<ThreadSafeConnection<M>> {
142 log::info!("Opening main db");
143 ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
144 .with_db_initialization_query(DB_INITIALIZE_QUERY)
145 .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
146 .build()
147 .await
148 .log_err()
149}
150
151async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> {
152 log::info!("Opening fallback db");
153 ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false)
154 .with_db_initialization_query(DB_INITIALIZE_QUERY)
155 .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
156 .build()
157 .await
158 .expect(
159 "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
160 )
161}
162
/// Opens a database named `db_name` for tests, with the builder's persistence flag set to
/// `false`.
///
/// Unlike the production path, writes go through `locking_queue`, which serializes queued
/// writes via a mutex and runs them synchronously.
///
/// # Panics
///
/// Panics if the database cannot be built.
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
    let connection = ThreadSafeConnection::<M>::builder(db_name, false)
        .with_db_initialization_query(DB_INITIALIZE_QUERY)
        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
        .with_write_queue_constructor(sqlez::thread_safe_connection::locking_queue())
        .build()
        .await;

    connection.unwrap()
}
176
/// Implements a basic DB wrapper for a given domain.
///
/// Two forms are accepted:
///
/// - `define_connection!(pub static ref DB: MyDb<()> = &[...];)` wraps a
///   `ThreadSafeConnection<MyDb>`.
/// - `define_connection!(pub static ref DB: MyDb<OtherDb, ...> = &[...];)` wraps a
///   `ThreadSafeConnection<(OtherDb, ..., MyDb)>`; presumably this lets the listed domains be
///   migrated on the same connection before `MyDb`.
///
/// Both forms declare the newtype `MyDb` (which derefs to its connection), implement
/// `sqlez::domain::Domain` for it with the given migrations, and declare a lazily opened static
/// `DB`. Under `cfg(test)` or the `test-support` feature, the static opens a test database named
/// after the static via `open_test_db`. Otherwise it opens the on-disk database in `DB_DIR` for
/// `RELEASE_CHANNEL` via `open_db`.
#[macro_export]
macro_rules! define_connection {
    (pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => {
        $crate::define_connection!(@define $id, $t, $t, $migrations);
    };
    (pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => {
        $crate::define_connection!(@define $id, $t, ($($d),+, $t), $migrations);
    };
    // Shared expansion for both public forms. `$migrator` is the type parameter of the wrapped
    // `ThreadSafeConnection`.
    (@define $id:ident, $t:ident, $migrator:ty, $migrations:expr) => {
        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$migrator>);

        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$migrator>;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl $crate::sqlez::domain::Domain for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn migrations() -> &'static [&'static str] {
                $migrations
            }
        }

        #[cfg(any(test, feature = "test-support"))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
        }

        #[cfg(not(any(test, feature = "test-support")))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
        }
    };
}
243
#[cfg(test)]
mod tests {
    use std::{fs, thread};

    use sqlez::{connection::Connection, domain::Domain};
    use sqlez_macros::sql;
    use tempdir::TempDir;

    use crate::{open_db, DB_FILE_NAME};

    /// A domain whose migrations fail even on a blank database cannot be opened from disk or
    /// from the in-memory fallback, so `open_db` is expected to panic.
    #[gpui::test]
    #[should_panic]
    async fn test_bad_migration_panics() {
        enum BadDB {}

        impl Domain for BadDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[
                    sql!(CREATE TABLE test(value);),
                    // failure because test already exists
                    sql!(CREATE TABLE test(value);),
                ]
            }
        }

        let tempdir = TempDir::new("DbTests").unwrap();
        let _bad_db = open_db::<BadDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
    }

    /// A database created by one domain is reopened by a domain with the same name but
    /// different migrations. `open_db` should move the old directory aside, keeping its data,
    /// and hand back a fresh database.
    #[gpui::test]
    async fn test_db_corruption() {
        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test(value);)]
            }
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            fn name() -> &'static str {
                "db_tests" //Notice same name
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test2(value);)] //But different migration
            }
        }

        let tempdir = TempDir::new("DbTests").unwrap();
        {
            let corrupt_db =
                open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
            assert!(corrupt_db.persistent());
        }

        let good_db = open_db::<GoodDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
        assert!(good_db.select_row::<usize>("SELECT * FROM test2").unwrap()().unwrap().is_none());

        // The timestamped backup is the only entry whose name does not start with `0`.
        let mut corrupted_backup_dir = fs::read_dir(tempdir.path())
            .unwrap()
            .find(|entry| {
                !entry
                    .as_ref()
                    .unwrap()
                    .file_name()
                    .to_str()
                    .unwrap()
                    .starts_with('0')
            })
            .unwrap()
            .unwrap()
            .path();
        corrupted_backup_dir.push(DB_FILE_NAME);

        let backup = Connection::open_file(&corrupted_backup_dir.to_string_lossy());
        assert!(backup.select_row::<usize>("SELECT * FROM test").unwrap()().unwrap().is_none());
    }

    /// Like `test_db_corruption`, but many threads open the mismatched database at once. Every
    /// thread must end up with a usable, freshly migrated database.
    #[gpui::test]
    async fn test_simultaneous_db_corruption() {
        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test(value);)]
            }
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            fn name() -> &'static str {
                "db_tests" //Notice same name
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test2(value);)] //But different migration
            }
        }

        let tempdir = TempDir::new("DbTests").unwrap();
        {
            let corrupt_db =
                open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
            assert!(corrupt_db.persistent());
        }

        let guards: Vec<_> = (0..10)
            .map(|_| {
                let tmp_path = tempdir.path().to_path_buf();
                thread::spawn(move || {
                    let good_db = smol::block_on(open_db::<GoodDB>(
                        tmp_path.as_path(),
                        &util::channel::ReleaseChannel::Dev,
                    ));
                    assert!(good_db
                        .select_row::<usize>("SELECT * FROM test2")
                        .unwrap()()
                    .unwrap()
                    .is_none());
                })
            })
            .collect();

        for guard in guards {
            assert!(guard.join().is_ok());
        }
    }
}