db.rs

  1pub mod kvp;
  2pub mod query;
  3
  4// Re-export
  5pub use anyhow;
  6use anyhow::Context;
  7pub use indoc::indoc;
  8pub use lazy_static;
  9use parking_lot::{Mutex, RwLock};
 10pub use smol;
 11pub use sqlez;
 12pub use sqlez_macros;
 13pub use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
 14pub use util::paths::DB_DIR;
 15
 16use sqlez::domain::Migrator;
 17use sqlez::thread_safe_connection::ThreadSafeConnection;
 18use sqlez_macros::sql;
 19use std::fs::{create_dir_all, remove_dir_all};
 20use std::path::{Path, PathBuf};
 21use std::sync::atomic::{AtomicBool, Ordering};
 22use std::time::{SystemTime, UNIX_EPOCH};
 23use util::{async_iife, ResultExt};
 24use util::channel::ReleaseChannel;
 25
 26const CONNECTION_INITIALIZE_QUERY: &'static str = sql!(
 27    PRAGMA synchronous=NORMAL;
 28    PRAGMA busy_timeout=1;
 29    PRAGMA foreign_keys=TRUE;
 30    PRAGMA case_sensitive_like=TRUE;
 31);
 32
 33const DB_INITIALIZE_QUERY: &'static str = sql!(
 34    PRAGMA journal_mode=WAL;
 35);
 36
 37const FALLBACK_DB_NAME: &'static str = "FALLBACK_MEMORY_DB";
 38
 39lazy_static::lazy_static! {
 40    static ref DB_FILE_OPERATIONS: Mutex<()> = Mutex::new(());
 41    static ref DB_WIPED: RwLock<bool> = RwLock::new(false);
 42    pub static ref BACKUP_DB_PATH: RwLock<Option<PathBuf>> = RwLock::new(None);
 43    pub static ref ALL_FILE_DB_FAILED: AtomicBool = AtomicBool::new(false);    
 44}
 45
 46/// Open or create a database at the given directory path.
 47/// This will retry a couple times if there are failures. If opening fails once, the db directory
 48/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
 49/// In either case, static variables are set so that the user can be notified.
 50pub async fn open_db<M: Migrator + 'static>(wipe_db: bool, db_dir: &Path, release_channel: &ReleaseChannel) -> ThreadSafeConnection<M> {
 51    let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel.name())));
 52
 53    // If WIPE_DB, delete 0-{channel}
 54    if release_channel == &ReleaseChannel::Dev
 55        && wipe_db
 56        && !*DB_WIPED.read()
 57    {
 58        let mut db_wiped = DB_WIPED.write();
 59        if !*db_wiped {
 60            remove_dir_all(&main_db_dir).ok();
 61            *db_wiped = true;
 62        }
 63    }
 64
 65    let connection = async_iife!({
 66        // Note: This still has a race condition where 1 set of migrations succeeds
 67        // (e.g. (Workspace, Editor)) and another fails (e.g. (Workspace, Terminal))
 68        // This will cause the first connection to have the database taken out 
 69        // from under it. This *should* be fine though. The second dabatase failure will
 70        // cause errors in the log and so should be observed by developers while writing
 71        // soon-to-be good migrations. If user databases are corrupted, we toss them out
 72        // and try again from a blank. As long as running all migrations from start to end 
 73        // on a blank database is ok, this race condition will never be triggered.
 74        //
 75        // Basically: Don't ever push invalid migrations to stable or everyone will have
 76        // a bad time.
 77        
 78        // If no db folder, create one at 0-{channel}
 79        create_dir_all(&main_db_dir).context("Could not create db directory")?;
 80        let db_path = main_db_dir.join(Path::new("db.sqlite"));
 81        
 82        // Optimistically open databases in parallel
 83        if !DB_FILE_OPERATIONS.is_locked() {
 84            // Try building a connection
 85            if let Some(connection) = open_main_db(&db_path).await {
 86                return Ok(connection)
 87            };
 88        }
 89        
 90        // Take a lock in the failure case so that we move the db once per process instead 
 91        // of potentially multiple times from different threads. This shouldn't happen in the
 92        // normal path
 93        let _lock = DB_FILE_OPERATIONS.lock();
 94        if let Some(connection) = open_main_db(&db_path).await {
 95            return Ok(connection)
 96        };
 97        
 98        let backup_timestamp = SystemTime::now()
 99            .duration_since(UNIX_EPOCH)
100            .expect("System clock is set before the unix timestamp, Zed does not support this region of spacetime")
101            .as_millis();
102        
103        // If failed, move 0-{channel} to {current unix timestamp}-{channel}
104        let backup_db_dir = db_dir.join(Path::new(&format!(
105            "{}-{}",
106            backup_timestamp,
107            release_channel.name(),
108        )));
109
110        std::fs::rename(&main_db_dir, &backup_db_dir)
111            .context("Failed clean up corrupted database, panicking.")?;
112
113        // Set a static ref with the failed timestamp and error so we can notify the user
114        {
115            let mut guard = BACKUP_DB_PATH.write();
116            *guard = Some(backup_db_dir);
117        }
118        
119        // Create a new 0-{channel}
120        create_dir_all(&main_db_dir).context("Should be able to create the database directory")?;
121        let db_path = main_db_dir.join(Path::new("db.sqlite"));
122
123        // Try again
124        open_main_db(&db_path).await.context("Could not newly created db")
125    }).await.log_err();
126
127    if let Some(connection) = connection {
128        return connection;
129    }
130   
131    // Set another static ref so that we can escalate the notification
132    ALL_FILE_DB_FAILED.store(true, Ordering::Release);
133    
134    // If still failed, create an in memory db with a known name
135    open_fallback_db().await
136}
137
138async fn open_main_db<M: Migrator>(db_path: &PathBuf) -> Option<ThreadSafeConnection<M>> {
139    log::info!("Opening main db");
140    ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
141        .with_db_initialization_query(DB_INITIALIZE_QUERY)
142        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
143        .build()
144        .await
145        .log_err()
146}
147
148async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> {
149    log::info!("Opening fallback db");
150    ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false)
151        .with_db_initialization_query(DB_INITIALIZE_QUERY)
152        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
153        .build()
154        .await
155        .expect(
156            "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
157        )
158}
159
/// Open a named database for tests, using the same pragmas as the real databases.
///
/// Writes go through `locking_queue()` instead of the default write queue.
///
/// # Panics
///
/// Panics if the connection cannot be built.
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
    use sqlez::thread_safe_connection::locking_queue;

    let built = ThreadSafeConnection::<M>::builder(db_name, false)
        .with_db_initialization_query(DB_INITIALIZE_QUERY)
        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
        // Queued writes are serialized behind a mutex and executed synchronously
        .with_write_queue_constructor(locking_queue())
        .build()
        .await;

    built.unwrap()
}
173
/// Implements a basic DB wrapper for a given domain.
///
/// Two forms are accepted:
///
/// ```ignore
/// define_connection!(pub static ref DB: MyDb<()> = &[sql!(...)];);
/// define_connection!(pub static ref DB: MyDb<WorkspaceDb, EditorDb> = &[sql!(...)];);
/// ```
///
/// Both define a newtype `MyDb` that derefs to a `ThreadSafeConnection`, implement
/// `Domain` for it (named after the type, with the given migrations), and declare a lazy
/// static `DB` holding the opened connection. The connection's type parameter is `MyDb`
/// in the first form and the tuple `(WorkspaceDb, EditorDb, MyDb)` in the second.
///
/// In test builds (`test` or the `test-support` feature) the static opens a test database
/// named after the static. Otherwise it opens the on-disk database in `DB_DIR` for
/// `RELEASE_CHANNEL`, wiping it first on Dev when the `WIPE_DB` environment variable is set.
#[macro_export]
macro_rules! define_connection {
    (pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => {
        $crate::define_connection!(@define $id, $t, $t, $migrations);
    };
    (pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => {
        $crate::define_connection!(@define $id, $t, ($($d),+, $t), $migrations);
    };
    // Shared expansion for both public forms. `$migrator` is the type parameter of the
    // wrapped `ThreadSafeConnection`.
    (@define $id:ident, $t:ident, $migrator:ty, $migrations:expr) => {
        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$migrator>);

        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$migrator>;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl $crate::sqlez::domain::Domain for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn migrations() -> &'static [&'static str] {
                $migrations
            }
        }

        #[cfg(any(test, feature = "test-support"))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
        }

        #[cfg(not(any(test, feature = "test-support")))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(
                ::std::env::var("WIPE_DB").is_ok(),
                &$crate::DB_DIR,
                &$crate::RELEASE_CHANNEL,
            )));
        }
    };
}
240
#[cfg(test)]
mod tests {
    use std::thread;

    use sqlez::domain::Domain;
    use sqlez_macros::sql;
    use tempdir::TempDir;
    use util::channel::ReleaseChannel;

    use crate::open_db;

    /// Minimal domain with a single one-column table, used by the tests below.
    enum TestDB {}

    impl Domain for TestDB {
        fn name() -> &'static str {
            "db_tests"
        }

        fn migrations() -> &'static [&'static str] {
            &[sql!(
                CREATE TABLE test(value);
            )]
        }
    }

    // Test that wipe_db exists and works and gives a new db
    #[test]
    fn test_wipe_db() {
        env_logger::try_init().ok();

        smol::block_on(async {
            let tempdir = TempDir::new("DbTests").unwrap();

            let test_db = open_db::<TestDB>(false, tempdir.path(), &ReleaseChannel::Dev).await;
            test_db.write(|connection|
                connection.exec(sql!(
                    INSERT INTO test(value) VALUES (10)
                )).unwrap()().unwrap()
            ).await;
            drop(test_db);

            // Reopen from several threads at once with wipe_db set. Only the first one to
            // take the DB_WIPED write lock deletes the directory, and every thread must
            // observe the fresh, empty table.
            let mut guards = vec![];
            for _ in 0..5 {
                let path = tempdir.path().to_path_buf();
                let guard = thread::spawn(move || smol::block_on(async {
                    let test_db = open_db::<TestDB>(true, &path, &ReleaseChannel::Dev).await;

                    assert!(test_db.select_row::<()>(sql!(SELECT value FROM test)).unwrap()().unwrap().is_none())
                }));

                guards.push(guard);
            }

            for guard in guards {
                guard.join().unwrap();
            }
        })
    }

    // Test a file system failure (like in create_dir_all())
    #[test]
    #[ignore = "not yet implemented"]
    fn test_file_system_failure() {}

    // Test happy path where everything exists and opens
    #[test]
    fn test_open_db() {
        env_logger::try_init().ok();

        smol::block_on(async {
            let tempdir = TempDir::new("DbTests").unwrap();

            // First open creates 0-{channel}/db.sqlite and runs the migrations
            let test_db = open_db::<TestDB>(false, tempdir.path(), &ReleaseChannel::Dev).await;
            test_db.write(|connection|
                connection.exec(sql!(
                    INSERT INTO test(value) VALUES (10)
                )).unwrap()().unwrap()
            ).await;
            drop(test_db);

            // Reopening without wiping must find the existing database and its row
            let test_db = open_db::<TestDB>(false, tempdir.path(), &ReleaseChannel::Dev).await;
            assert!(test_db.select_row::<()>(sql!(SELECT value FROM test)).unwrap()().unwrap().is_some())
        })
    }

    // Test bad migration panics
    #[test]
    #[ignore = "not yet implemented"]
    fn test_bad_migration_panics() {}

    /// Test that DB exists but corrupted (causing recreate)
    #[test]
    #[ignore = "not yet implemented"]
    fn test_db_corruption() {}
}