db.rs

  1pub mod kvp;
  2pub mod query;
  3
  4// Re-export
  5pub use anyhow;
  6use anyhow::Context;
  7pub use indoc::indoc;
  8pub use lazy_static;
  9use parking_lot::{Mutex, RwLock};
 10pub use smol;
 11pub use sqlez;
 12pub use sqlez_macros;
 13pub use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
 14pub use util::paths::DB_DIR;
 15
 16use sqlez::domain::Migrator;
 17use sqlez::thread_safe_connection::ThreadSafeConnection;
 18use sqlez_macros::sql;
 19use std::fs::create_dir_all;
 20use std::path::{Path, PathBuf};
 21use std::sync::atomic::{AtomicBool, Ordering};
 22use std::time::{SystemTime, UNIX_EPOCH};
 23use util::{async_iife, ResultExt};
 24use util::channel::ReleaseChannel;
 25
 26const CONNECTION_INITIALIZE_QUERY: &'static str = sql!(
 27    PRAGMA synchronous=NORMAL;
 28    PRAGMA busy_timeout=1;
 29    PRAGMA foreign_keys=TRUE;
 30    PRAGMA case_sensitive_like=TRUE;
 31);
 32
 33const DB_INITIALIZE_QUERY: &'static str = sql!(
 34    PRAGMA journal_mode=WAL;
 35);
 36
 37const FALLBACK_DB_NAME: &'static str = "FALLBACK_MEMORY_DB";
 38
 39const DB_FILE_NAME: &'static str = "db.sqlite";
 40
 41lazy_static::lazy_static! {
 42    static ref DB_FILE_OPERATIONS: Mutex<()> = Mutex::new(());
 43    // static ref DB_WIPED: RwLock<bool> = RwLock::new(false);
 44    pub static ref BACKUP_DB_PATH: RwLock<Option<PathBuf>> = RwLock::new(None);
 45    pub static ref ALL_FILE_DB_FAILED: AtomicBool = AtomicBool::new(false);    
 46}
 47
 48/// Open or create a database at the given directory path.
 49/// This will retry a couple times if there are failures. If opening fails once, the db directory
 50/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
 51/// In either case, static variables are set so that the user can be notified.
 52pub async fn open_db<M: Migrator + 'static>(db_dir: &Path, release_channel: &ReleaseChannel) -> ThreadSafeConnection<M> {
 53    let release_channel_name = release_channel.dev_name();
 54    let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel_name)));
 55
 56    // // If WIPE_DB, delete 0-{channel}
 57    // if release_channel == &ReleaseChannel::Dev
 58    //     && wipe_db
 59    //     && !*DB_WIPED.read()
 60    // {
 61    //     let mut db_wiped = DB_WIPED.write();
 62    //     if !*db_wiped {
 63    //         remove_dir_all(&main_db_dir).ok();
 64    //         *db_wiped = true;
 65    //     }
 66    // }
 67
 68    let connection = async_iife!({
 69        // Note: This still has a race condition where 1 set of migrations succeeds
 70        // (e.g. (Workspace, Editor)) and another fails (e.g. (Workspace, Terminal))
 71        // This will cause the first connection to have the database taken out 
 72        // from under it. This *should* be fine though. The second dabatase failure will
 73        // cause errors in the log and so should be observed by developers while writing
 74        // soon-to-be good migrations. If user databases are corrupted, we toss them out
 75        // and try again from a blank. As long as running all migrations from start to end 
 76        // on a blank database is ok, this race condition will never be triggered.
 77        //
 78        // Basically: Don't ever push invalid migrations to stable or everyone will have
 79        // a bad time.
 80        
 81        // If no db folder, create one at 0-{channel}
 82        create_dir_all(&main_db_dir).context("Could not create db directory")?;
 83        let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
 84        
 85        // Optimistically open databases in parallel
 86        if !DB_FILE_OPERATIONS.is_locked() {
 87            // Try building a connection
 88            if let Some(connection) = open_main_db(&db_path).await {
 89                return Ok(connection)
 90            };
 91        }
 92        
 93        // Take a lock in the failure case so that we move the db once per process instead 
 94        // of potentially multiple times from different threads. This shouldn't happen in the
 95        // normal path
 96        let _lock = DB_FILE_OPERATIONS.lock();
 97        if let Some(connection) = open_main_db(&db_path).await {
 98            return Ok(connection)
 99        };
100        
101        let backup_timestamp = SystemTime::now()
102            .duration_since(UNIX_EPOCH)
103            .expect("System clock is set before the unix timestamp, Zed does not support this region of spacetime")
104            .as_millis();
105        
106        // If failed, move 0-{channel} to {current unix timestamp}-{channel}
107        let backup_db_dir = db_dir.join(Path::new(&format!(
108            "{}-{}",
109            backup_timestamp,
110            release_channel_name,
111        )));
112
113        std::fs::rename(&main_db_dir, &backup_db_dir)
114            .context("Failed clean up corrupted database, panicking.")?;
115
116        // Set a static ref with the failed timestamp and error so we can notify the user
117        {
118            let mut guard = BACKUP_DB_PATH.write();
119            *guard = Some(backup_db_dir);
120        }
121        
122        // Create a new 0-{channel}
123        create_dir_all(&main_db_dir).context("Should be able to create the database directory")?;
124        let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
125
126        // Try again
127        open_main_db(&db_path).await.context("Could not newly created db")
128    }).await.log_err();
129
130    if let Some(connection) = connection {
131        return connection;
132    }
133   
134    // Set another static ref so that we can escalate the notification
135    ALL_FILE_DB_FAILED.store(true, Ordering::Release);
136    
137    // If still failed, create an in memory db with a known name
138    open_fallback_db().await
139}
140
141async fn open_main_db<M: Migrator>(db_path: &PathBuf) -> Option<ThreadSafeConnection<M>> {
142    log::info!("Opening main db");
143    ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
144        .with_db_initialization_query(DB_INITIALIZE_QUERY)
145        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
146        .build()
147        .await
148        .log_err()
149}
150
151async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> {
152    log::info!("Opening fallback db");
153    ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false)
154        .with_db_initialization_query(DB_INITIALIZE_QUERY)
155        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
156        .build()
157        .await
158        .expect(
159            "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
160        )
161}
162
/// Open a database named `db_name` for use in tests.
///
/// Uses the same initialization queries as the production databases, but installs the
/// `locking_queue` write queue.
///
/// # Panics
///
/// Panics if the connection cannot be built.
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
    use sqlez::thread_safe_connection::locking_queue;

    let test_builder = ThreadSafeConnection::<M>::builder(db_name, false)
        .with_db_initialization_query(DB_INITIALIZE_QUERY)
        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
        // Serialize queued writes via a mutex and run them synchronously
        .with_write_queue_constructor(locking_queue());

    let build_result = test_builder.build().await;
    build_result.unwrap()
}
176
/// Implements a basic DB wrapper for a given domain.
///
/// For an invocation of the form
///
/// ```ignore
/// define_connection! {
///     pub static ref MY_DB: MyDb<()> = &[/* migration strings */];
/// }
/// ```
///
/// the macro generates:
///
/// * a tuple struct `MyDb` wrapping a `ThreadSafeConnection`, with a `Deref` impl
///   exposing that connection;
/// * a `Domain` impl for `MyDb` whose `name()` is the stringified type name and whose
///   `migrations()` returns the given expression;
/// * a `lazy_static` static `MY_DB`. When compiled with `test` or the `test-support`
///   feature it is created with `open_test_db` (named after the static's identifier),
///   otherwise with `open_db` using `DB_DIR` and `RELEASE_CHANNEL`. Both are driven to
///   completion with `smol::block_on`.
///
/// The second form, `MyDb<Dep1, Dep2, ...>`, is identical except that the wrapped
/// connection is parameterized by the tuple `(Dep1, Dep2, ..., MyDb)` instead of `MyDb`
/// alone.
#[macro_export]
macro_rules! define_connection {
    // Form 1: a domain with no dependencies, written as `$t<()>`.
    (pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => {
        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>);

        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }
        
        impl $crate::sqlez::domain::Domain for $t {
            fn name() -> &'static str {
                stringify!($t)
            }
            
            fn migrations() -> &'static [&'static str] {
                $migrations
            } 
        }

        // Test builds: a test database named after the static's identifier.
        #[cfg(any(test, feature = "test-support"))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
        }

        // Non-test builds: the database under DB_DIR for the current release channel.
        #[cfg(not(any(test, feature = "test-support")))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
        }
    };
    // Form 2: a domain listing one or more other types; the connection is parameterized
    // by the tuple of those types followed by `$t` itself.
    (pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => {
        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<( $($d),+, $t )>);

        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<($($d),+, $t)>;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }
        
        impl $crate::sqlez::domain::Domain for $t {
            fn name() -> &'static str {
                stringify!($t)
            }
            
            fn migrations() -> &'static [&'static str] {
                $migrations
            } 
        }

        // Test builds: a test database named after the static's identifier.
        #[cfg(any(test, feature = "test-support"))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
        }

        // Non-test builds: the database under DB_DIR for the current release channel.
        #[cfg(not(any(test, feature = "test-support")))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
        }
    };
}
243
244#[cfg(test)]
245mod tests {
246    use std::{fs, thread};
247
248    use sqlez::{domain::Domain, connection::Connection};
249    use sqlez_macros::sql;
250    use tempdir::TempDir;
251
252    use crate::{open_db, DB_FILE_NAME};
253    
254    // // Test that wipe_db exists and works and gives a new db
255    // #[gpui::test]
256    // async fn test_wipe_db() {
257    //     enum TestDB {}
258        
259    //     impl Domain for TestDB {
260    //         fn name() -> &'static str {
261    //             "db_tests"
262    //         }
263            
264    //         fn migrations() -> &'static [&'static str] {
265    //             &[sql!(
266    //                 CREATE TABLE test(value);
267    //             )]
268    //         }
269    //     }
270        
271    //     let tempdir = TempDir::new("DbTests").unwrap();
272        
273    //     // Create a db and insert a marker value
274    //     let test_db = open_db::<TestDB>(false, tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
275    //     test_db.write(|connection|  
276    //         connection.exec(sql!(
277    //             INSERT INTO test(value) VALUES (10)
278    //         )).unwrap()().unwrap()
279    //     ).await;
280    //     drop(test_db);
281        
282    //     // Opening db with wipe clears once and removes the marker value
283    //     let mut guards = vec![];
284    //     for _ in 0..5 {
285    //         let path = tempdir.path().to_path_buf();
286    //         let guard = thread::spawn(move || smol::block_on(async {
287    //             let test_db = open_db::<TestDB>(true, &path, &ReleaseChannel::Dev).await;
288                
289    //             assert!(test_db.select_row::<()>(sql!(SELECT value FROM test)).unwrap()().unwrap().is_none())
290    //         }));
291            
292    //         guards.push(guard);
293    //     }
294        
295    //     for guard in guards {
296    //         guard.join().unwrap();
297    //     }
298    // }
299        
300    // Test bad migration panics
301    #[gpui::test]
302    #[should_panic]
303    async fn test_bad_migration_panics() {
304        enum BadDB {}
305        
306        impl Domain for BadDB {
307            fn name() -> &'static str {
308                "db_tests"
309            }
310            
311            fn migrations() -> &'static [&'static str] {
312                &[sql!(CREATE TABLE test(value);),
313                    // failure because test already exists
314                  sql!(CREATE TABLE test(value);)]
315            }
316        }
317       
318        let tempdir = TempDir::new("DbTests").unwrap();
319        let _bad_db = open_db::<BadDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
320    }
321    
322    /// Test that DB exists but corrupted (causing recreate)
323    #[gpui::test]
324    async fn test_db_corruption() {
325        enum CorruptedDB {}
326        
327        impl Domain for CorruptedDB {
328            fn name() -> &'static str {
329                "db_tests"
330            }
331            
332            fn migrations() -> &'static [&'static str] {
333                &[sql!(CREATE TABLE test(value);)]
334            }
335        }
336        
337        enum GoodDB {}
338        
339        impl Domain for GoodDB {
340            fn name() -> &'static str {
341                "db_tests" //Notice same name
342            }
343            
344            fn migrations() -> &'static [&'static str] {
345                &[sql!(CREATE TABLE test2(value);)] //But different migration
346            }
347        }
348       
349        let tempdir = TempDir::new("DbTests").unwrap();
350        {
351            let corrupt_db = open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
352            assert!(corrupt_db.persistent());
353        }
354        
355        let good_db = open_db::<GoodDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
356        assert!(good_db.select_row::<usize>("SELECT * FROM test2").unwrap()().unwrap().is_none());
357        
358        let mut corrupted_backup_dir = fs::read_dir(
359            tempdir.path()
360        ).unwrap().find(|entry| {
361            !entry.as_ref().unwrap().file_name().to_str().unwrap().starts_with("0")
362        }
363        ).unwrap().unwrap().path();
364        corrupted_backup_dir.push(DB_FILE_NAME);
365        
366        dbg!(&corrupted_backup_dir);
367        
368        let backup = Connection::open_file(&corrupted_backup_dir.to_string_lossy());
369        assert!(backup.select_row::<usize>("SELECT * FROM test").unwrap()().unwrap().is_none());
370    }
371    
372    /// Test that DB exists but corrupted (causing recreate)
373    #[gpui::test]
374    async fn test_simultaneous_db_corruption() {
375        enum CorruptedDB {}
376        
377        impl Domain for CorruptedDB {
378            fn name() -> &'static str {
379                "db_tests"
380            }
381            
382            fn migrations() -> &'static [&'static str] {
383                &[sql!(CREATE TABLE test(value);)]
384            }
385        }
386        
387        enum GoodDB {}
388        
389        impl Domain for GoodDB {
390            fn name() -> &'static str {
391                "db_tests" //Notice same name
392            }
393            
394            fn migrations() -> &'static [&'static str] {
395                &[sql!(CREATE TABLE test2(value);)] //But different migration
396            }
397        }
398       
399        let tempdir = TempDir::new("DbTests").unwrap();
400        {
401            let corrupt_db = open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
402            assert!(corrupt_db.persistent());
403        }
404        
405        let mut guards = vec![];
406        for _ in 0..10 {
407            let tmp_path = tempdir.path().to_path_buf();
408            let guard = thread::spawn(move || {
409                let good_db = smol::block_on(open_db::<GoodDB>(tmp_path.as_path(), &util::channel::ReleaseChannel::Dev));
410                assert!(good_db.select_row::<usize>("SELECT * FROM test2").unwrap()().unwrap().is_none());
411            });
412            
413            guards.push(guard);
414        
415        }
416        
417       for guard in guards.into_iter() {
418           assert!(guard.join().is_ok());
419       }
420    }
421}