db.rs

  1pub mod kvp;
  2pub mod query;
  3
  4// Re-export
  5pub use anyhow;
  6use anyhow::Context;
  7use gpui::AppContext;
  8pub use indoc::indoc;
  9pub use lazy_static;
 10use parking_lot::{Mutex, RwLock};
 11pub use smol;
 12pub use sqlez;
 13pub use sqlez_macros;
 14pub use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
 15pub use util::paths::DB_DIR;
 16
 17use sqlez::domain::Migrator;
 18use sqlez::thread_safe_connection::ThreadSafeConnection;
 19use sqlez_macros::sql;
 20use std::fs::create_dir_all;
 21use std::future::Future;
 22use std::path::{Path, PathBuf};
 23use std::sync::atomic::{AtomicBool, Ordering};
 24use std::time::{SystemTime, UNIX_EPOCH};
 25use util::channel::ReleaseChannel;
 26use util::{async_iife, ResultExt};
 27
 28const CONNECTION_INITIALIZE_QUERY: &'static str = sql!(
 29    PRAGMA foreign_keys=TRUE;
 30);
 31
 32const DB_INITIALIZE_QUERY: &'static str = sql!(
 33    PRAGMA journal_mode=WAL;
 34    PRAGMA busy_timeout=1;
 35    PRAGMA case_sensitive_like=TRUE;
 36    PRAGMA synchronous=NORMAL;
 37);
 38
 39const FALLBACK_DB_NAME: &'static str = "FALLBACK_MEMORY_DB";
 40
 41const DB_FILE_NAME: &'static str = "db.sqlite";
 42
 43lazy_static::lazy_static! {
 44    // !!!!!!! CHANGE BACK TO DEFAULT FALSE BEFORE SHIPPING
 45    static ref ZED_STATELESS: bool = std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty());
 46    static ref DB_FILE_OPERATIONS: Mutex<()> = Mutex::new(());
 47    pub static ref BACKUP_DB_PATH: RwLock<Option<PathBuf>> = RwLock::new(None);
 48    pub static ref ALL_FILE_DB_FAILED: AtomicBool = AtomicBool::new(false);
 49}
 50
 51/// Open or create a database at the given directory path.
 52/// This will retry a couple times if there are failures. If opening fails once, the db directory
 53/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
 54/// In either case, static variables are set so that the user can be notified.
 55pub async fn open_db<M: Migrator + 'static>(
 56    db_dir: &Path,
 57    release_channel: &ReleaseChannel,
 58) -> ThreadSafeConnection<M> {
 59    if *ZED_STATELESS {
 60        return open_fallback_db().await;
 61    }
 62
 63    let release_channel_name = release_channel.dev_name();
 64    let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel_name)));
 65
 66    let connection = async_iife!({
 67        // Note: This still has a race condition where 1 set of migrations succeeds
 68        // (e.g. (Workspace, Editor)) and another fails (e.g. (Workspace, Terminal))
 69        // This will cause the first connection to have the database taken out
 70        // from under it. This *should* be fine though. The second dabatase failure will
 71        // cause errors in the log and so should be observed by developers while writing
 72        // soon-to-be good migrations. If user databases are corrupted, we toss them out
 73        // and try again from a blank. As long as running all migrations from start to end
 74        // on a blank database is ok, this race condition will never be triggered.
 75        //
 76        // Basically: Don't ever push invalid migrations to stable or everyone will have
 77        // a bad time.
 78
 79        // If no db folder, create one at 0-{channel}
 80        create_dir_all(&main_db_dir).context("Could not create db directory")?;
 81        let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
 82
 83        // Optimistically open databases in parallel
 84        if !DB_FILE_OPERATIONS.is_locked() {
 85            // Try building a connection
 86            if let Some(connection) = open_main_db(&db_path).await {
 87                return Ok(connection)
 88            };
 89        }
 90
 91        // Take a lock in the failure case so that we move the db once per process instead
 92        // of potentially multiple times from different threads. This shouldn't happen in the
 93        // normal path
 94        let _lock = DB_FILE_OPERATIONS.lock();
 95        if let Some(connection) = open_main_db(&db_path).await {
 96            return Ok(connection)
 97        };
 98
 99        let backup_timestamp = SystemTime::now()
100            .duration_since(UNIX_EPOCH)
101            .expect("System clock is set before the unix timestamp, Zed does not support this region of spacetime")
102            .as_millis();
103
104        // If failed, move 0-{channel} to {current unix timestamp}-{channel}
105        let backup_db_dir = db_dir.join(Path::new(&format!(
106            "{}-{}",
107            backup_timestamp,
108            release_channel_name,
109        )));
110
111        std::fs::rename(&main_db_dir, &backup_db_dir)
112            .context("Failed clean up corrupted database, panicking.")?;
113
114        // Set a static ref with the failed timestamp and error so we can notify the user
115        {
116            let mut guard = BACKUP_DB_PATH.write();
117            *guard = Some(backup_db_dir);
118        }
119
120        // Create a new 0-{channel}
121        create_dir_all(&main_db_dir).context("Should be able to create the database directory")?;
122        let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
123
124        // Try again
125        open_main_db(&db_path).await.context("Could not newly created db")
126    }).await.log_err();
127
128    if let Some(connection) = connection {
129        return connection;
130    }
131
132    // Set another static ref so that we can escalate the notification
133    ALL_FILE_DB_FAILED.store(true, Ordering::Release);
134
135    // If still failed, create an in memory db with a known name
136    open_fallback_db().await
137}
138
139async fn open_main_db<M: Migrator>(db_path: &PathBuf) -> Option<ThreadSafeConnection<M>> {
140    log::info!("Opening main db");
141    ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
142        .with_db_initialization_query(DB_INITIALIZE_QUERY)
143        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
144        .build()
145        .await
146        .log_err()
147}
148
149async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> {
150    log::info!("Opening fallback db");
151    ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false)
152        .with_db_initialization_query(DB_INITIALIZE_QUERY)
153        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
154        .build()
155        .await
156        .expect(
157            "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
158        )
159}
160
/// Opens the in-memory (non-persistent) database named `db_name` for tests.
///
/// Queued writes are serialized via a mutex and run synchronously, using sqlez's
/// `locking_queue` as the write-queue constructor.
///
/// # Panics
///
/// Panics if the connection cannot be built (e.g. an initialization query or migration fails).
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
    let connection = ThreadSafeConnection::<M>::builder(db_name, false)
        .with_db_initialization_query(DB_INITIALIZE_QUERY)
        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
        .with_write_queue_constructor(sqlez::thread_safe_connection::locking_queue())
        .build()
        .await;
    connection.unwrap()
}
174
/// Implements a basic DB wrapper for a given domain.
///
/// Two forms are accepted:
///
/// ```ignore
/// // Connection typed over the domain alone:
/// define_connection!(pub static ref DB: MyDb<()> = &[sql!(CREATE TABLE ...)];);
/// // Connection typed over the tuple `(Dep1, Dep2, MyDb)`:
/// define_connection!(pub static ref DB: MyDb<Dep1, Dep2> = &[sql!(CREATE TABLE ...)];);
/// ```
///
/// Both forms define a newtype `$t` wrapping a `ThreadSafeConnection`, which it derefs to. They
/// implement `Domain` for `$t`, using the type's name and the given migrations. They also declare
/// a lazily-initialized static `$id`, whose first access blocks the current thread (via
/// `smol::block_on`) until the database is open:
///
/// - Under `test` / `test-support`, it opens a test database named after `$id`.
/// - Otherwise, it opens the real database in `DB_DIR` for `RELEASE_CHANNEL`.
#[macro_export]
macro_rules! define_connection {
    (pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => {
        $crate::define_connection!(@define $id, $t, $t, $migrations);
    };
    (pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => {
        $crate::define_connection!(@define $id, $t, ($($d),+, $t), $migrations);
    };
    // Shared expansion for both public forms. `$migrator` is the type parameter of the wrapped
    // `ThreadSafeConnection`; it is the only thing that differs between the two forms.
    (@define $id:ident, $t:ident, $migrator:ty, $migrations:expr) => {
        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$migrator>);

        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$migrator>;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl $crate::sqlez::domain::Domain for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn migrations() -> &'static [&'static str] {
                $migrations
            }
        }

        #[cfg(any(test, feature = "test-support"))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
        }

        #[cfg(not(any(test, feature = "test-support")))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
        }
    };
}
241
242pub fn write_and_log<F>(cx: &mut AppContext, db_write: impl FnOnce() -> F + Send + 'static)
243where
244    F: Future<Output = anyhow::Result<()>> + Send,
245{
246    cx.background()
247        .spawn(async move { db_write().await.log_err() })
248        .detach()
249}
250
#[cfg(test)]
mod tests {
    use std::{fs, thread};

    use sqlez::{connection::Connection, domain::Domain};
    use sqlez_macros::sql;
    use tempdir::TempDir;

    use crate::{open_db, DB_FILE_NAME};

    /// Shares `GoodDB`'s domain name but has a different migration. A database first created
    /// through it therefore cannot be opened as `GoodDB`, and has to be moved aside and recreated.
    enum CorruptedDB {}

    impl Domain for CorruptedDB {
        fn name() -> &'static str {
            "db_tests"
        }

        fn migrations() -> &'static [&'static str] {
            &[sql!(CREATE TABLE test(value);)]
        }
    }

    /// Same domain name as `CorruptedDB`, but a different migration.
    enum GoodDB {}

    impl Domain for GoodDB {
        fn name() -> &'static str {
            "db_tests" // Notice same name
        }

        fn migrations() -> &'static [&'static str] {
            &[sql!(CREATE TABLE test2(value);)] // But different migration
        }
    }

    /// A migration list that fails on any database, including the in-memory fallback, so
    /// `open_db` runs out of options and panics.
    #[gpui::test]
    #[should_panic]
    async fn test_bad_migration_panics() {
        enum BadDB {}

        impl Domain for BadDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[
                    sql!(CREATE TABLE test(value);),
                    // failure because test already exists
                    sql!(CREATE TABLE test(value);),
                ]
            }
        }

        let tempdir = TempDir::new("DbTests").unwrap();
        let _bad_db = open_db::<BadDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
    }

    /// Test that DB exists but corrupted (causing recreate)
    #[gpui::test]
    async fn test_db_corruption() {
        let tempdir = TempDir::new("DbTests").unwrap();
        {
            let corrupt_db =
                open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
            assert!(corrupt_db.persistent());
        }

        let good_db = open_db::<GoodDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
        // The in-memory fallback would also have `test2`. Make sure recovery produced a fresh
        // on-disk database instead.
        assert!(good_db.persistent());
        assert!(
            good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                .unwrap()
                .is_none()
        );

        // The old database should have been moved to a `{timestamp}-{channel}` directory.
        let mut corrupted_backup_dir = fs::read_dir(tempdir.path())
            .unwrap()
            .map(|entry| entry.unwrap().path())
            .find(|path| !path.file_name().unwrap().to_str().unwrap().starts_with("0"))
            .expect("corrupted database directory should have been moved to a backup");
        corrupted_backup_dir.push(DB_FILE_NAME);

        let backup = Connection::open_file(&corrupted_backup_dir.to_string_lossy());
        assert!(backup.select_row::<usize>("SELECT * FROM test").unwrap()()
            .unwrap()
            .is_none());
    }

    /// Test that many threads opening the same corrupted DB at once all end up with a
    /// working connection.
    #[gpui::test]
    async fn test_simultaneous_db_corruption() {
        let tempdir = TempDir::new("DbTests").unwrap();
        {
            // Setup the bad database
            let corrupt_db =
                open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
            assert!(corrupt_db.persistent());
        }

        // Try to connect to it a bunch of times at once
        let mut guards = vec![];
        for _ in 0..10 {
            let tmp_path = tempdir.path().to_path_buf();
            let guard = thread::spawn(move || {
                let good_db = smol::block_on(open_db::<GoodDB>(
                    tmp_path.as_path(),
                    &util::channel::ReleaseChannel::Dev,
                ));
                assert!(
                    good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                        .unwrap()
                        .is_none()
                );
            });

            guards.push(guard);
        }

        for guard in guards.into_iter() {
            assert!(guard.join().is_ok());
        }
    }
}