db.rs

  1pub mod kvp;
  2pub mod query;
  3
  4// Re-export
  5pub use anyhow;
  6use anyhow::Context;
  7pub use indoc::indoc;
  8pub use lazy_static;
  9use parking_lot::{Mutex, RwLock};
 10pub use smol;
 11pub use sqlez;
 12pub use sqlez_macros;
 13pub use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
 14pub use util::paths::DB_DIR;
 15
 16use sqlez::domain::Migrator;
 17use sqlez::thread_safe_connection::ThreadSafeConnection;
 18use sqlez_macros::sql;
 19use std::fs::create_dir_all;
 20use std::path::{Path, PathBuf};
 21use std::sync::atomic::{AtomicBool, Ordering};
 22use std::time::{SystemTime, UNIX_EPOCH};
 23use util::channel::ReleaseChannel;
 24use util::{async_iife, ResultExt};
 25
 26const CONNECTION_INITIALIZE_QUERY: &'static str = sql!(
 27    PRAGMA foreign_keys=TRUE;
 28);
 29
 30const DB_INITIALIZE_QUERY: &'static str = sql!(
 31    PRAGMA journal_mode=WAL;
 32    PRAGMA busy_timeout=1;
 33    PRAGMA case_sensitive_like=TRUE;
 34    PRAGMA synchronous=NORMAL;
 35);
 36
 37const FALLBACK_DB_NAME: &'static str = "FALLBACK_MEMORY_DB";
 38
 39const DB_FILE_NAME: &'static str = "db.sqlite";
 40
 41lazy_static::lazy_static! {
 42    // !!!!!!! CHANGE BACK TO DEFAULT FALSE BEFORE SHIPPING
 43    static ref ZED_STATELESS: bool = std::env::var("ZED_STATELESS").map_or(true, |v| !v.is_empty());
 44    static ref DB_FILE_OPERATIONS: Mutex<()> = Mutex::new(());
 45    pub static ref BACKUP_DB_PATH: RwLock<Option<PathBuf>> = RwLock::new(None);
 46    pub static ref ALL_FILE_DB_FAILED: AtomicBool = AtomicBool::new(false);
 47}
 48
 49/// Open or create a database at the given directory path.
 50/// This will retry a couple times if there are failures. If opening fails once, the db directory
 51/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
 52/// In either case, static variables are set so that the user can be notified.
 53pub async fn open_db<M: Migrator + 'static>(
 54    db_dir: &Path,
 55    release_channel: &ReleaseChannel,
 56) -> ThreadSafeConnection<M> {
 57    if *ZED_STATELESS {
 58        return open_fallback_db().await;
 59    }
 60
 61    let release_channel_name = release_channel.dev_name();
 62    let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel_name)));
 63
 64    let connection = async_iife!({
 65        // Note: This still has a race condition where 1 set of migrations succeeds
 66        // (e.g. (Workspace, Editor)) and another fails (e.g. (Workspace, Terminal))
 67        // This will cause the first connection to have the database taken out 
 68        // from under it. This *should* be fine though. The second dabatase failure will
 69        // cause errors in the log and so should be observed by developers while writing
 70        // soon-to-be good migrations. If user databases are corrupted, we toss them out
 71        // and try again from a blank. As long as running all migrations from start to end 
 72        // on a blank database is ok, this race condition will never be triggered.
 73        //
 74        // Basically: Don't ever push invalid migrations to stable or everyone will have
 75        // a bad time.
 76
 77        // If no db folder, create one at 0-{channel}
 78        create_dir_all(&main_db_dir).context("Could not create db directory")?;
 79        let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
 80
 81        // Optimistically open databases in parallel
 82        if !DB_FILE_OPERATIONS.is_locked() {
 83            // Try building a connection
 84            if let Some(connection) = open_main_db(&db_path).await {
 85                return Ok(connection)
 86            };
 87        }
 88
 89        // Take a lock in the failure case so that we move the db once per process instead 
 90        // of potentially multiple times from different threads. This shouldn't happen in the
 91        // normal path
 92        let _lock = DB_FILE_OPERATIONS.lock();
 93        if let Some(connection) = open_main_db(&db_path).await {
 94            return Ok(connection)
 95        };
 96
 97        let backup_timestamp = SystemTime::now()
 98            .duration_since(UNIX_EPOCH)
 99            .expect("System clock is set before the unix timestamp, Zed does not support this region of spacetime")
100            .as_millis();
101
102        // If failed, move 0-{channel} to {current unix timestamp}-{channel}
103        let backup_db_dir = db_dir.join(Path::new(&format!(
104            "{}-{}",
105            backup_timestamp,
106            release_channel_name,
107        )));
108
109        std::fs::rename(&main_db_dir, &backup_db_dir)
110            .context("Failed clean up corrupted database, panicking.")?;
111
112        // Set a static ref with the failed timestamp and error so we can notify the user
113        {
114            let mut guard = BACKUP_DB_PATH.write();
115            *guard = Some(backup_db_dir);
116        }
117
118        // Create a new 0-{channel}
119        create_dir_all(&main_db_dir).context("Should be able to create the database directory")?;
120        let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
121
122        // Try again
123        open_main_db(&db_path).await.context("Could not newly created db")
124    }).await.log_err();
125
126    if let Some(connection) = connection {
127        return connection;
128    }
129
130    // Set another static ref so that we can escalate the notification
131    ALL_FILE_DB_FAILED.store(true, Ordering::Release);
132
133    // If still failed, create an in memory db with a known name
134    open_fallback_db().await
135}
136
137async fn open_main_db<M: Migrator>(db_path: &PathBuf) -> Option<ThreadSafeConnection<M>> {
138    log::info!("Opening main db");
139    ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
140        .with_db_initialization_query(DB_INITIALIZE_QUERY)
141        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
142        .build()
143        .await
144        .log_err()
145}
146
147async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> {
148    log::info!("Opening fallback db");
149    ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false)
150        .with_db_initialization_query(DB_INITIALIZE_QUERY)
151        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
152        .build()
153        .await
154        .expect(
155            "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
156        )
157}
158
/// Builds a connection named `db_name` for use in tests.
///
/// It passes `false` as the builder's second argument (as the in-memory fallback does) and
/// installs `locking_queue` as the write queue. That queue serializes queued writes via a mutex
/// and runs them synchronously.
///
/// # Panics
///
/// Panics if the connection cannot be built.
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
    let build_result = ThreadSafeConnection::<M>::builder(db_name, false)
        .with_db_initialization_query(DB_INITIALIZE_QUERY)
        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
        // Serialize queued writes via a mutex and run them synchronously
        .with_write_queue_constructor(sqlez::thread_safe_connection::locking_queue())
        .build()
        .await;

    build_result.unwrap()
}
172
/// Implements a basic DB wrapper for a given domain.
///
/// `define_connection!(pub static ref ID: Name<Deps> = MIGRATIONS;)` generates:
///
/// - `pub struct Name`, a newtype around a `ThreadSafeConnection` that derefs to it. With
///   `Name<()>` the connection's migrator type is `Name` alone. With `Name<A, B, ...>` it is the
///   tuple `(A, B, ..., Name)`.
/// - An `impl Domain for Name`. Its `name()` is the identifier `Name` (via `stringify!`) and its
///   `migrations()` is the `MIGRATIONS` expression.
/// - A lazily initialized `pub static ref ID: Name`:
///   - In test builds (`cfg(test)` or the `test-support` feature) it is opened with
///     `open_test_db`, named after `ID`.
///   - Otherwise it is opened with `open_db` on `DB_DIR` for `RELEASE_CHANNEL`.
///   - Either way, the first access blocks on `smol::block_on`.
///
/// Note: the `cfg` attributes are part of the expansion, so `test` and `test-support` are
/// evaluated for the crate that invokes the macro, not for this crate.
///
/// ```ignore
/// define_connection!(pub static ref MY_DB: MyDb<()> = &[sql!(CREATE TABLE items(id INTEGER);)];);
/// ```
#[macro_export]
macro_rules! define_connection {
    // Domain with no dependencies: the connection migrates only `$t`.
    (pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => {
        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>);

        // Deref so callers can use the connection's methods directly on the wrapper.
        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        // The domain name is the wrapper's type name.
        impl $crate::sqlez::domain::Domain for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn migrations() -> &'static [&'static str] {
                $migrations
            }
        }

        // Test builds: a separate test database named after the static.
        #[cfg(any(test, feature = "test-support"))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
        }

        // Normal builds: the shared on-disk database for the current release channel.
        #[cfg(not(any(test, feature = "test-support")))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
        }
    };
    // Domain with dependencies: the connection's migrator is the tuple `($d..., $t)`.
    (pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => {
        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<( $($d),+, $t )>);

        // Deref so callers can use the connection's methods directly on the wrapper.
        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<($($d),+, $t)>;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        // The domain name is the wrapper's type name.
        impl $crate::sqlez::domain::Domain for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn migrations() -> &'static [&'static str] {
                $migrations
            }
        }

        // Test builds: a separate test database named after the static.
        #[cfg(any(test, feature = "test-support"))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
        }

        // Normal builds: the shared on-disk database for the current release channel.
        #[cfg(not(any(test, feature = "test-support")))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
        }
    };
}
239
#[cfg(test)]
mod tests {
    use std::{fs, thread};

    use sqlez::{connection::Connection, domain::Domain};
    use sqlez_macros::sql;
    use tempdir::TempDir;

    use crate::{open_db, DB_FILE_NAME};

    /// A domain whose migrations fail even on a blank database makes `open_db` panic.
    /// Every on-disk attempt fails, and the in-memory fallback (which `expect`s success)
    /// presumably fails the same migrations.
    #[gpui::test]
    #[should_panic]
    async fn test_bad_migration_panics() {
        enum BadDB {}

        impl Domain for BadDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[
                    sql!(CREATE TABLE test(value);),
                    // failure because test already exists
                    sql!(CREATE TABLE test(value);),
                ]
            }
        }

        let tempdir = TempDir::new("DbTests").unwrap();
        let _bad_db = open_db::<BadDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
    }

    /// Test that a DB that exists but cannot be opened by a domain (same domain name,
    /// different migrations) is moved to a timestamped backup and recreated from blank.
    #[gpui::test]
    async fn test_db_corruption() {
        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test(value);)]
            }
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            fn name() -> &'static str {
                "db_tests" //Notice same name
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test2(value);)] //But different migration
            }
        }

        let tempdir = TempDir::new("DbTests").unwrap();
        // Create the on-disk database with CorruptedDB's migrations; the scope drops the
        // connection before GoodDB reopens the same directory.
        {
            let corrupt_db =
                open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
            assert!(corrupt_db.persistent());
        }

        // GoodDB must get a fresh database containing its own table.
        let good_db = open_db::<GoodDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
        assert!(
            good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                .unwrap()
                .is_none()
        );

        // The backup directory is the only entry not named `0-{channel}`.
        let mut corrupted_backup_dir = fs::read_dir(tempdir.path())
            .unwrap()
            .find(|entry| {
                !entry
                    .as_ref()
                    .unwrap()
                    .file_name()
                    .to_str()
                    .unwrap()
                    .starts_with("0")
            })
            .unwrap()
            .unwrap()
            .path();
        corrupted_backup_dir.push(DB_FILE_NAME);

        // The backup still holds CorruptedDB's table.
        let backup = Connection::open_file(&corrupted_backup_dir.to_string_lossy());
        assert!(backup.select_row::<usize>("SELECT * FROM test").unwrap()()
            .unwrap()
            .is_none());
    }

    /// Like `test_db_corruption`, but ten threads open the mismatched database at the same
    /// time. Each must end up with a working GoodDB connection (exercises the
    /// `DB_FILE_OPERATIONS` locking in `open_db`).
    #[gpui::test]
    async fn test_simultaneous_db_corruption() {
        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test(value);)]
            }
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            fn name() -> &'static str {
                "db_tests" //Notice same name
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test2(value);)] //But different migration
            }
        }

        let tempdir = TempDir::new("DbTests").unwrap();
        {
            // Setup the bad database
            let corrupt_db =
                open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
            assert!(corrupt_db.persistent());
        }

        // Try to connect to it a bunch of times at once
        let mut guards = vec![];
        for _ in 0..10 {
            let tmp_path = tempdir.path().to_path_buf();
            // Each thread drives its own `open_db` future with `smol::block_on`.
            let guard = thread::spawn(move || {
                let good_db = smol::block_on(open_db::<GoodDB>(
                    tmp_path.as_path(),
                    &util::channel::ReleaseChannel::Dev,
                ));
                assert!(
                    good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                        .unwrap()
                        .is_none()
                );
            });

            guards.push(guard);
        }

        // A failed assertion in any thread surfaces as a join error.
        for guard in guards.into_iter() {
            assert!(guard.join().is_ok());
        }
    }
}