db.rs

  1pub mod kvp;
  2pub mod query;
  3
  4// Re-export
  5pub use anyhow;
  6use anyhow::Context;
  7pub use indoc::indoc;
  8pub use lazy_static;
  9use parking_lot::{Mutex, RwLock};
 10pub use smol;
 11pub use sqlez;
 12pub use sqlez_macros;
 13pub use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
 14pub use util::paths::DB_DIR;
 15
 16use sqlez::domain::Migrator;
 17use sqlez::thread_safe_connection::ThreadSafeConnection;
 18use sqlez_macros::sql;
 19use std::fs::create_dir_all;
 20use std::path::{Path, PathBuf};
 21use std::sync::atomic::{AtomicBool, Ordering};
 22use std::time::{SystemTime, UNIX_EPOCH};
 23use util::channel::ReleaseChannel;
 24use util::{async_iife, ResultExt};
 25
 26const CONNECTION_INITIALIZE_QUERY: &'static str = sql!(
 27    PRAGMA foreign_keys=TRUE;
 28);
 29
 30const DB_INITIALIZE_QUERY: &'static str = sql!(
 31    PRAGMA journal_mode=WAL;
 32    PRAGMA busy_timeout=1;
 33    PRAGMA case_sensitive_like=TRUE;
 34    PRAGMA synchronous=NORMAL;
 35);
 36
 37const FALLBACK_DB_NAME: &'static str = "FALLBACK_MEMORY_DB";
 38
 39const DB_FILE_NAME: &'static str = "db.sqlite";
 40
 41lazy_static::lazy_static! {
 42    static ref ZED_STATELESS: bool = std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty());
 43    static ref DB_FILE_OPERATIONS: Mutex<()> = Mutex::new(());
 44    pub static ref BACKUP_DB_PATH: RwLock<Option<PathBuf>> = RwLock::new(None);
 45    pub static ref ALL_FILE_DB_FAILED: AtomicBool = AtomicBool::new(false);
 46}
 47
 48/// Open or create a database at the given directory path.
 49/// This will retry a couple times if there are failures. If opening fails once, the db directory
 50/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
 51/// In either case, static variables are set so that the user can be notified.
 52pub async fn open_db<M: Migrator + 'static>(
 53    db_dir: &Path,
 54    release_channel: &ReleaseChannel,
 55) -> ThreadSafeConnection<M> {
 56    if *ZED_STATELESS {
 57        return open_fallback_db().await;
 58    }
 59
 60    let release_channel_name = release_channel.dev_name();
 61    let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel_name)));
 62
 63    let connection = async_iife!({
 64        // Note: This still has a race condition where 1 set of migrations succeeds
 65        // (e.g. (Workspace, Editor)) and another fails (e.g. (Workspace, Terminal))
 66        // This will cause the first connection to have the database taken out 
 67        // from under it. This *should* be fine though. The second dabatase failure will
 68        // cause errors in the log and so should be observed by developers while writing
 69        // soon-to-be good migrations. If user databases are corrupted, we toss them out
 70        // and try again from a blank. As long as running all migrations from start to end 
 71        // on a blank database is ok, this race condition will never be triggered.
 72        //
 73        // Basically: Don't ever push invalid migrations to stable or everyone will have
 74        // a bad time.
 75
 76        // If no db folder, create one at 0-{channel}
 77        create_dir_all(&main_db_dir).context("Could not create db directory")?;
 78        let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
 79
 80        // Optimistically open databases in parallel
 81        if !DB_FILE_OPERATIONS.is_locked() {
 82            // Try building a connection
 83            if let Some(connection) = open_main_db(&db_path).await {
 84                return Ok(connection)
 85            };
 86        }
 87
 88        // Take a lock in the failure case so that we move the db once per process instead 
 89        // of potentially multiple times from different threads. This shouldn't happen in the
 90        // normal path
 91        let _lock = DB_FILE_OPERATIONS.lock();
 92        if let Some(connection) = open_main_db(&db_path).await {
 93            return Ok(connection)
 94        };
 95
 96        let backup_timestamp = SystemTime::now()
 97            .duration_since(UNIX_EPOCH)
 98            .expect("System clock is set before the unix timestamp, Zed does not support this region of spacetime")
 99            .as_millis();
100
101        // If failed, move 0-{channel} to {current unix timestamp}-{channel}
102        let backup_db_dir = db_dir.join(Path::new(&format!(
103            "{}-{}",
104            backup_timestamp,
105            release_channel_name,
106        )));
107
108        std::fs::rename(&main_db_dir, &backup_db_dir)
109            .context("Failed clean up corrupted database, panicking.")?;
110
111        // Set a static ref with the failed timestamp and error so we can notify the user
112        {
113            let mut guard = BACKUP_DB_PATH.write();
114            *guard = Some(backup_db_dir);
115        }
116
117        // Create a new 0-{channel}
118        create_dir_all(&main_db_dir).context("Should be able to create the database directory")?;
119        let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
120
121        // Try again
122        open_main_db(&db_path).await.context("Could not newly created db")
123    }).await.log_err();
124
125    if let Some(connection) = connection {
126        return connection;
127    }
128
129    // Set another static ref so that we can escalate the notification
130    ALL_FILE_DB_FAILED.store(true, Ordering::Release);
131
132    // If still failed, create an in memory db with a known name
133    open_fallback_db().await
134}
135
136async fn open_main_db<M: Migrator>(db_path: &PathBuf) -> Option<ThreadSafeConnection<M>> {
137    log::info!("Opening main db");
138    ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
139        .with_db_initialization_query(DB_INITIALIZE_QUERY)
140        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
141        .build()
142        .await
143        .log_err()
144}
145
146async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> {
147    log::info!("Opening fallback db");
148    ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false)
149        .with_db_initialization_query(DB_INITIALIZE_QUERY)
150        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
151        .build()
152        .await
153        .expect(
154            "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
155        )
156}
157
/// Opens a database named `db_name` for tests, using the same initialization
/// queries as the other `open_*` functions but with the `locking_queue` write
/// queue.
///
/// # Panics
///
/// Panics if the connection cannot be built.
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
    use sqlez::thread_safe_connection::locking_queue;

    let builder = ThreadSafeConnection::<M>::builder(db_name, false)
        .with_db_initialization_query(DB_INITIALIZE_QUERY)
        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
        // Serialize queued writes via a mutex and run them synchronously
        .with_write_queue_constructor(locking_queue());

    let build_result = builder.build().await;
    build_result.unwrap()
}
171
/// Implements a basic DB wrapper for a given domain
///
/// Given `pub static ref NAME: Type<Deps...> = MIGRATIONS;` this expands to:
/// - a tuple struct `Type` wrapping a `ThreadSafeConnection` and dereferencing to it,
/// - a `Domain` impl for `Type` whose name is `stringify!(Type)` and whose migrations
///   are `MIGRATIONS`,
/// - a lazily initialized static `NAME`, opened with `open_test_db` under `test` or the
///   `test-support` feature, and with `open_db(&DB_DIR, &RELEASE_CHANNEL)` otherwise.
///
/// Use `Type<()>` for a domain with no dependencies; otherwise list the domains it
/// depends on, e.g. `Type<A, B>`.
#[macro_export]
macro_rules! define_connection {
    // Arm 1: no dependencies, the connection is parameterized by `$t` alone.
    (pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => {
        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>);

        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl $crate::sqlez::domain::Domain for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn migrations() -> &'static [&'static str] {
                $migrations
            }
        }

        // Test builds open a db named after the static's identifier.
        #[cfg(any(test, feature = "test-support"))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
        }

        // Other builds block on opening the db under DB_DIR for the current release channel.
        #[cfg(not(any(test, feature = "test-support")))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
        }
    };
    // Arm 2: one or more dependencies, the connection is parameterized by the
    // tuple `(deps..., $t)`.
    (pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => {
        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<( $($d),+, $t )>);

        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<($($d),+, $t)>;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl $crate::sqlez::domain::Domain for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn migrations() -> &'static [&'static str] {
                $migrations
            }
        }

        // Test builds open a db named after the static's identifier.
        #[cfg(any(test, feature = "test-support"))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
        }

        // Other builds block on opening the db under DB_DIR for the current release channel.
        #[cfg(not(any(test, feature = "test-support")))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
        }
    };
}
238
#[cfg(test)]
mod tests {
    use std::{fs, thread};

    use sqlez::{connection::Connection, domain::Domain};
    use sqlez_macros::sql;
    use tempdir::TempDir;

    use crate::{open_db, DB_FILE_NAME};

    /// Test that a migration which fails even on a blank database makes `open_db` panic
    #[gpui::test]
    #[should_panic]
    async fn test_bad_migration_panics() {
        enum BadDB {}

        impl Domain for BadDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[
                    sql!(CREATE TABLE test(value);),
                    // failure because test already exists
                    sql!(CREATE TABLE test(value);),
                ]
            }
        }

        let tempdir = TempDir::new("DbTests").unwrap();
        let _bad_db = open_db::<BadDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
    }

    /// Test that a DB which exists but cannot be opened with the requested migrations
    /// is moved to a backup directory and recreated
    #[gpui::test]
    async fn test_db_corruption() {
        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test(value);)]
            }
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            fn name() -> &'static str {
                "db_tests" //Notice same name
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test2(value);)] //But different migration
            }
        }

        let tempdir = TempDir::new("DbTests").unwrap();
        {
            // Create the database on disk with the first set of migrations
            let corrupt_db =
                open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
            assert!(corrupt_db.persistent());
        }

        // Opening with the conflicting migrations must yield a fresh, empty database
        let good_db = open_db::<GoodDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
        assert!(
            good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                .unwrap()
                .is_none()
        );

        // The live db directory is named `0-{channel}`; any other entry is the backup
        let mut corrupted_backup_dir = fs::read_dir(tempdir.path())
            .unwrap()
            .find(|entry| {
                !entry
                    .as_ref()
                    .unwrap()
                    .file_name()
                    .to_str()
                    .unwrap()
                    .starts_with('0')
            })
            .unwrap()
            .unwrap()
            .path();
        corrupted_backup_dir.push(DB_FILE_NAME);

        // The backup must still contain the table created by the original migrations
        let backup = Connection::open_file(&corrupted_backup_dir.to_string_lossy());
        assert!(backup.select_row::<usize>("SELECT * FROM test").unwrap()()
            .unwrap()
            .is_none());
    }

    /// Test that many threads opening the same unopenable DB at once all end up
    /// with a usable, recreated database
    #[gpui::test]
    async fn test_simultaneous_db_corruption() {
        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test(value);)]
            }
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            fn name() -> &'static str {
                "db_tests" //Notice same name
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test2(value);)] //But different migration
            }
        }

        let tempdir = TempDir::new("DbTests").unwrap();
        {
            // Setup the bad database
            let corrupt_db =
                open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
            assert!(corrupt_db.persistent());
        }

        // Try to connect to it a bunch of times at once
        let mut handles = vec![];
        for _ in 0..10 {
            let tmp_path = tempdir.path().to_path_buf();
            let handle = thread::spawn(move || {
                let good_db = smol::block_on(open_db::<GoodDB>(
                    tmp_path.as_path(),
                    &util::channel::ReleaseChannel::Dev,
                ));
                assert!(
                    good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                        .unwrap()
                        .is_none()
                );
            });

            handles.push(handle);
        }

        // Every thread must finish without panicking
        for handle in handles {
            assert!(
                handle.join().is_ok(),
                "a thread opening the database panicked"
            );
        }
    }
}