db.rs

  1pub mod kvp;
  2pub mod query;
  3
  4// Re-export
  5pub use anyhow;
  6use anyhow::Context;
  7pub use indoc::indoc;
  8pub use lazy_static;
  9use parking_lot::{Mutex, RwLock};
 10pub use smol;
 11pub use sqlez;
 12pub use sqlez_macros;
 13pub use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
 14pub use util::paths::DB_DIR;
 15
 16use sqlez::domain::Migrator;
 17use sqlez::thread_safe_connection::ThreadSafeConnection;
 18use sqlez_macros::sql;
 19use std::fs::create_dir_all;
 20use std::path::{Path, PathBuf};
 21use std::sync::atomic::{AtomicBool, Ordering};
 22use std::time::{SystemTime, UNIX_EPOCH};
 23use util::{async_iife, ResultExt};
 24use util::channel::ReleaseChannel;
 25
 26const CONNECTION_INITIALIZE_QUERY: &'static str = sql!(
 27    PRAGMA foreign_keys=TRUE;
 28);
 29
 30const DB_INITIALIZE_QUERY: &'static str = sql!(
 31    PRAGMA journal_mode=WAL;
 32    PRAGMA busy_timeout=1;
 33    PRAGMA case_sensitive_like=TRUE;
 34    PRAGMA synchronous=NORMAL;
 35);
 36
 37const FALLBACK_DB_NAME: &'static str = "FALLBACK_MEMORY_DB";
 38
 39const DB_FILE_NAME: &'static str = "db.sqlite";
 40
 41lazy_static::lazy_static! {
 42    static ref ZED_STATELESS: bool = std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty());
 43    static ref DB_FILE_OPERATIONS: Mutex<()> = Mutex::new(());
 44    pub static ref BACKUP_DB_PATH: RwLock<Option<PathBuf>> = RwLock::new(None);
 45    pub static ref ALL_FILE_DB_FAILED: AtomicBool = AtomicBool::new(false);    
 46}
 47
 48/// Open or create a database at the given directory path.
 49/// This will retry a couple times if there are failures. If opening fails once, the db directory
 50/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
 51/// In either case, static variables are set so that the user can be notified.
 52pub async fn open_db<M: Migrator + 'static>(db_dir: &Path, release_channel: &ReleaseChannel) -> ThreadSafeConnection<M> {
 53    if *ZED_STATELESS {
 54        return open_fallback_db().await;
 55    }
 56
 57    let release_channel_name = release_channel.dev_name();
 58    let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel_name)));
 59
 60    let connection = async_iife!({
 61        // Note: This still has a race condition where 1 set of migrations succeeds
 62        // (e.g. (Workspace, Editor)) and another fails (e.g. (Workspace, Terminal))
 63        // This will cause the first connection to have the database taken out 
 64        // from under it. This *should* be fine though. The second dabatase failure will
 65        // cause errors in the log and so should be observed by developers while writing
 66        // soon-to-be good migrations. If user databases are corrupted, we toss them out
 67        // and try again from a blank. As long as running all migrations from start to end 
 68        // on a blank database is ok, this race condition will never be triggered.
 69        //
 70        // Basically: Don't ever push invalid migrations to stable or everyone will have
 71        // a bad time.
 72        
 73        // If no db folder, create one at 0-{channel}
 74        create_dir_all(&main_db_dir).context("Could not create db directory")?;
 75        let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
 76        
 77        // Optimistically open databases in parallel
 78        if !DB_FILE_OPERATIONS.is_locked() {
 79            // Try building a connection
 80            if let Some(connection) = open_main_db(&db_path).await {
 81                return Ok(connection)
 82            };
 83        }
 84        
 85        // Take a lock in the failure case so that we move the db once per process instead 
 86        // of potentially multiple times from different threads. This shouldn't happen in the
 87        // normal path
 88        let _lock = DB_FILE_OPERATIONS.lock();
 89        if let Some(connection) = open_main_db(&db_path).await {
 90            return Ok(connection)
 91        };
 92        
 93        let backup_timestamp = SystemTime::now()
 94            .duration_since(UNIX_EPOCH)
 95            .expect("System clock is set before the unix timestamp, Zed does not support this region of spacetime")
 96            .as_millis();
 97        
 98        // If failed, move 0-{channel} to {current unix timestamp}-{channel}
 99        let backup_db_dir = db_dir.join(Path::new(&format!(
100            "{}-{}",
101            backup_timestamp,
102            release_channel_name,
103        )));
104
105        std::fs::rename(&main_db_dir, &backup_db_dir)
106            .context("Failed clean up corrupted database, panicking.")?;
107
108        // Set a static ref with the failed timestamp and error so we can notify the user
109        {
110            let mut guard = BACKUP_DB_PATH.write();
111            *guard = Some(backup_db_dir);
112        }
113        
114        // Create a new 0-{channel}
115        create_dir_all(&main_db_dir).context("Should be able to create the database directory")?;
116        let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
117
118        // Try again
119        open_main_db(&db_path).await.context("Could not newly created db")
120    }).await.log_err();
121
122    if let Some(connection) = connection {
123        return connection;
124    }
125   
126    // Set another static ref so that we can escalate the notification
127    ALL_FILE_DB_FAILED.store(true, Ordering::Release);
128    
129    // If still failed, create an in memory db with a known name
130    open_fallback_db().await
131}
132
133async fn open_main_db<M: Migrator>(db_path: &PathBuf) -> Option<ThreadSafeConnection<M>> {
134    log::info!("Opening main db");
135    ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
136        .with_db_initialization_query(DB_INITIALIZE_QUERY)
137        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
138        .build()
139        .await
140        .log_err()
141}
142
143async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> {
144    log::info!("Opening fallback db");
145    ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false)
146        .with_db_initialization_query(DB_INITIALIZE_QUERY)
147        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
148        .build()
149        .await
150        .expect(
151            "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
152        )
153}
154
/// Opens a database named `db_name` for tests, using the same initialization queries as
/// the production openers.
///
/// Only compiled for tests or with the `test-support` feature.
///
/// # Panics
///
/// Panics if the connection cannot be built.
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
    use sqlez::thread_safe_connection::locking_queue;

    let builder = ThreadSafeConnection::<M>::builder(db_name, false)
        .with_db_initialization_query(DB_INITIALIZE_QUERY)
        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
        // Serialize queued writes via a mutex and run them synchronously
        .with_write_queue_constructor(locking_queue());

    builder.build().await.unwrap()
}
168
/// Implements a basic DB wrapper for a given domain.
///
/// Two invocation forms are accepted:
///
/// ```ignore
/// // A domain with no dependencies:
/// define_connection!(pub static ref MY_DB: MyDb<()> = MIGRATIONS;);
/// // A domain that depends on other domains, which are migrated first in the tuple:
/// define_connection!(pub static ref MY_DB: MyDb<OtherDb> = MIGRATIONS;);
/// ```
///
/// Each form generates:
///
/// * `pub struct $t`, a newtype around a `ThreadSafeConnection` whose migrator is `$t`
///   itself (first form) or the tuple `(deps.., $t)` (second form), with `Deref` to the
///   wrapped connection;
/// * an `impl Domain for $t` whose name is the stringified type name and whose
///   migrations are `$migrations`;
/// * a lazily initialized `pub static ref $id: $t`. Under `test` or the `test-support`
///   feature it blocks on `open_test_db(stringify!($id))`; otherwise it blocks on
///   `open_db(&DB_DIR, &RELEASE_CHANNEL)`.
///
/// The `@define` arm is an implementation detail shared by both public forms and is not
/// meant to be invoked directly.
#[macro_export]
macro_rules! define_connection {
    // Stand-alone domain: the connection only migrates `$t`.
    (pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => {
        $crate::define_connection!(@define $id, $t, $t, $migrations);
    };
    // Domain with dependencies: the dependencies come first in the migrator tuple.
    (pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => {
        $crate::define_connection!(@define $id, $t, ( $($d),+, $t ), $migrations);
    };
    // Shared expansion; `$migrator` is the type parameter of the wrapped connection.
    (@define $id:ident, $t:ident, $migrator:ty, $migrations:expr) => {
        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$migrator>);

        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$migrator>;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl $crate::sqlez::domain::Domain for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn migrations() -> &'static [&'static str] {
                $migrations
            }
        }

        #[cfg(any(test, feature = "test-support"))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
        }

        #[cfg(not(any(test, feature = "test-support")))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
        }
    };
}
235
#[cfg(test)]
mod tests {
    use std::{fs, thread};

    use sqlez::{connection::Connection, domain::Domain};
    use sqlez_macros::sql;
    use tempdir::TempDir;

    use crate::{open_db, DB_FILE_NAME};

    /// Domain used to seed the on-disk database: its only migration creates table `test`.
    enum CorruptedDB {}

    impl Domain for CorruptedDB {
        fn name() -> &'static str {
            "db_tests"
        }

        fn migrations() -> &'static [&'static str] {
            &[sql!(CREATE TABLE test(value);)]
        }
    }

    /// Domain with the same name as `CorruptedDB` but a different migration. The tests
    /// below expect that opening it on top of a database created by `CorruptedDB` makes
    /// `open_db` back up the old database and start from a blank one.
    enum GoodDB {}

    impl Domain for GoodDB {
        fn name() -> &'static str {
            "db_tests" // Notice same name
        }

        fn migrations() -> &'static [&'static str] {
            &[sql!(CREATE TABLE test2(value);)] // But different migration
        }
    }

    /// Test that a set of migrations which cannot succeed even on a blank database
    /// ends in a panic (the in-memory fallback cannot be opened either).
    #[gpui::test]
    #[should_panic]
    async fn test_bad_migration_panics() {
        enum BadDB {}

        impl Domain for BadDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[
                    sql!(CREATE TABLE test(value);),
                    // failure because test already exists
                    sql!(CREATE TABLE test(value);),
                ]
            }
        }

        let tempdir = TempDir::new("DbTests").unwrap();
        let _bad_db = open_db::<BadDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
    }

    /// Test that DB exists but corrupted (causing recreate)
    #[gpui::test]
    async fn test_db_corruption() {
        let tempdir = TempDir::new("DbTests").unwrap();
        {
            let corrupt_db =
                open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
            assert!(corrupt_db.persistent());
        }

        let good_db = open_db::<GoodDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
        assert!(good_db.select_row::<usize>("SELECT * FROM test2").unwrap()().unwrap().is_none());

        // The live database directory is named `0-{channel}`; any other entry is the backup.
        let mut corrupted_backup_dir = fs::read_dir(tempdir.path())
            .unwrap()
            .find(|entry| {
                !entry
                    .as_ref()
                    .unwrap()
                    .file_name()
                    .to_str()
                    .unwrap()
                    .starts_with("0")
            })
            .unwrap()
            .unwrap()
            .path();
        corrupted_backup_dir.push(DB_FILE_NAME);

        // The backup must still contain the table created by the original migrations.
        let backup = Connection::open_file(&corrupted_backup_dir.to_string_lossy());
        assert!(backup.select_row::<usize>("SELECT * FROM test").unwrap()().unwrap().is_none());
    }

    /// Test that several threads opening a corrupted DB at the same time all end up
    /// with a usable, freshly migrated database.
    #[gpui::test]
    async fn test_simultaneous_db_corruption() {
        let tempdir = TempDir::new("DbTests").unwrap();
        {
            // Setup the bad database
            let corrupt_db =
                open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
            assert!(corrupt_db.persistent());
        }

        // Try to connect to it a bunch of times at once
        let mut guards = vec![];
        for _ in 0..10 {
            let tmp_path = tempdir.path().to_path_buf();
            let guard = thread::spawn(move || {
                let good_db = smol::block_on(open_db::<GoodDB>(
                    tmp_path.as_path(),
                    &util::channel::ReleaseChannel::Dev,
                ));
                assert!(good_db.select_row::<usize>("SELECT * FROM test2").unwrap()().unwrap().is_none());
            });

            guards.push(guard);
        }

        for guard in guards.into_iter() {
            assert!(guard.join().is_ok());
        }
    }
}