db.rs

  1pub mod kvp;
  2pub mod query;
  3
  4// Re-export
  5pub use anyhow;
  6use anyhow::Context;
  7use gpui::AppContext;
  8pub use indoc::indoc;
  9pub use lazy_static;
 10use parking_lot::{Mutex, RwLock};
 11pub use smol;
 12pub use sqlez;
 13pub use sqlez_macros;
 14pub use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
 15pub use util::paths::DB_DIR;
 16
 17use sqlez::domain::Migrator;
 18use sqlez::thread_safe_connection::ThreadSafeConnection;
 19use sqlez_macros::sql;
 20use std::fs::create_dir_all;
 21use std::future::Future;
 22use std::path::{Path, PathBuf};
 23use std::sync::atomic::{AtomicBool, Ordering};
 24use std::time::{SystemTime, UNIX_EPOCH};
 25use util::channel::ReleaseChannel;
 26use util::{async_iife, ResultExt};
 27
 28const CONNECTION_INITIALIZE_QUERY: &'static str = sql!(
 29    PRAGMA foreign_keys=TRUE;
 30);
 31
 32const DB_INITIALIZE_QUERY: &'static str = sql!(
 33    PRAGMA journal_mode=WAL;
 34    PRAGMA busy_timeout=1;
 35    PRAGMA case_sensitive_like=TRUE;
 36    PRAGMA synchronous=NORMAL;
 37);
 38
 39const FALLBACK_DB_NAME: &'static str = "FALLBACK_MEMORY_DB";
 40
 41const DB_FILE_NAME: &'static str = "db.sqlite";
 42
 43lazy_static::lazy_static! {
 44    pub static ref ZED_STATELESS: bool = std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty());
 45    pub static ref BACKUP_DB_PATH: RwLock<Option<PathBuf>> = RwLock::new(None);
 46    pub static ref ALL_FILE_DB_FAILED: AtomicBool = AtomicBool::new(false);
 47}
 48static DB_FILE_OPERATIONS: Mutex<()> = Mutex::new(());
 49
 50/// Open or create a database at the given directory path.
 51/// This will retry a couple times if there are failures. If opening fails once, the db directory
 52/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
 53/// In either case, static variables are set so that the user can be notified.
 54pub async fn open_db<M: Migrator + 'static>(
 55    db_dir: &Path,
 56    release_channel: &ReleaseChannel,
 57) -> ThreadSafeConnection<M> {
 58    if *ZED_STATELESS {
 59        return open_fallback_db().await;
 60    }
 61
 62    let release_channel_name = release_channel.dev_name();
 63    let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel_name)));
 64
 65    let connection = async_iife!({
 66        // Note: This still has a race condition where 1 set of migrations succeeds
 67        // (e.g. (Workspace, Editor)) and another fails (e.g. (Workspace, Terminal))
 68        // This will cause the first connection to have the database taken out
 69        // from under it. This *should* be fine though. The second dabatase failure will
 70        // cause errors in the log and so should be observed by developers while writing
 71        // soon-to-be good migrations. If user databases are corrupted, we toss them out
 72        // and try again from a blank. As long as running all migrations from start to end
 73        // on a blank database is ok, this race condition will never be triggered.
 74        //
 75        // Basically: Don't ever push invalid migrations to stable or everyone will have
 76        // a bad time.
 77
 78        // If no db folder, create one at 0-{channel}
 79        create_dir_all(&main_db_dir).context("Could not create db directory")?;
 80        let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
 81
 82        // Optimistically open databases in parallel
 83        if !DB_FILE_OPERATIONS.is_locked() {
 84            // Try building a connection
 85            if let Some(connection) = open_main_db(&db_path).await {
 86                return Ok(connection)
 87            };
 88        }
 89
 90        // Take a lock in the failure case so that we move the db once per process instead
 91        // of potentially multiple times from different threads. This shouldn't happen in the
 92        // normal path
 93        let _lock = DB_FILE_OPERATIONS.lock();
 94        if let Some(connection) = open_main_db(&db_path).await {
 95            return Ok(connection)
 96        };
 97
 98        let backup_timestamp = SystemTime::now()
 99            .duration_since(UNIX_EPOCH)
100            .expect("System clock is set before the unix timestamp, Zed does not support this region of spacetime")
101            .as_millis();
102
103        // If failed, move 0-{channel} to {current unix timestamp}-{channel}
104        let backup_db_dir = db_dir.join(Path::new(&format!(
105            "{}-{}",
106            backup_timestamp,
107            release_channel_name,
108        )));
109
110        std::fs::rename(&main_db_dir, &backup_db_dir)
111            .context("Failed clean up corrupted database, panicking.")?;
112
113        // Set a static ref with the failed timestamp and error so we can notify the user
114        {
115            let mut guard = BACKUP_DB_PATH.write();
116            *guard = Some(backup_db_dir);
117        }
118
119        // Create a new 0-{channel}
120        create_dir_all(&main_db_dir).context("Should be able to create the database directory")?;
121        let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
122
123        // Try again
124        open_main_db(&db_path).await.context("Could not newly created db")
125    }).await.log_err();
126
127    if let Some(connection) = connection {
128        return connection;
129    }
130
131    // Set another static ref so that we can escalate the notification
132    ALL_FILE_DB_FAILED.store(true, Ordering::Release);
133
134    // If still failed, create an in memory db with a known name
135    open_fallback_db().await
136}
137
138async fn open_main_db<M: Migrator>(db_path: &PathBuf) -> Option<ThreadSafeConnection<M>> {
139    log::info!("Opening main db");
140    ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
141        .with_db_initialization_query(DB_INITIALIZE_QUERY)
142        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
143        .build()
144        .await
145        .log_err()
146}
147
148async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> {
149    log::info!("Opening fallback db");
150    ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false)
151        .with_db_initialization_query(DB_INITIALIZE_QUERY)
152        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
153        .build()
154        .await
155        .expect(
156            "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
157        )
158}
159
160#[cfg(any(test, feature = "test-support"))]
161pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
162    use sqlez::thread_safe_connection::locking_queue;
163
164    ThreadSafeConnection::<M>::builder(db_name, false)
165        .with_db_initialization_query(DB_INITIALIZE_QUERY)
166        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
167        // Serialize queued writes via a mutex and run them synchronously
168        .with_write_queue_constructor(locking_queue())
169        .build()
170        .await
171        .unwrap()
172}
173
174/// Implements a basic DB wrapper for a given domain
175#[macro_export]
176macro_rules! define_connection {
177    (pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => {
178        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>);
179
180        impl ::std::ops::Deref for $t {
181            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>;
182
183            fn deref(&self) -> &Self::Target {
184                &self.0
185            }
186        }
187
188        impl $crate::sqlez::domain::Domain for $t {
189            fn name() -> &'static str {
190                stringify!($t)
191            }
192
193            fn migrations() -> &'static [&'static str] {
194                $migrations
195            }
196        }
197
198        #[cfg(any(test, feature = "test-support"))]
199        $crate::lazy_static::lazy_static! {
200            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
201        }
202
203        #[cfg(not(any(test, feature = "test-support")))]
204        $crate::lazy_static::lazy_static! {
205            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
206        }
207    };
208    (pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => {
209        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<( $($d),+, $t )>);
210
211        impl ::std::ops::Deref for $t {
212            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<($($d),+, $t)>;
213
214            fn deref(&self) -> &Self::Target {
215                &self.0
216            }
217        }
218
219        impl $crate::sqlez::domain::Domain for $t {
220            fn name() -> &'static str {
221                stringify!($t)
222            }
223
224            fn migrations() -> &'static [&'static str] {
225                $migrations
226            }
227        }
228
229        #[cfg(any(test, feature = "test-support"))]
230        $crate::lazy_static::lazy_static! {
231            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
232        }
233
234        #[cfg(not(any(test, feature = "test-support")))]
235        $crate::lazy_static::lazy_static! {
236            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
237        }
238    };
239}
240
241pub fn write_and_log<F>(cx: &mut AppContext, db_write: impl FnOnce() -> F + Send + 'static)
242where
243    F: Future<Output = anyhow::Result<()>> + Send,
244{
245    cx.background()
246        .spawn(async move { db_write().await.log_err() })
247        .detach()
248}
249
250#[cfg(test)]
251mod tests {
252    use std::{fs, thread};
253
254    use sqlez::{connection::Connection, domain::Domain};
255    use sqlez_macros::sql;
256    use tempdir::TempDir;
257
258    use crate::{open_db, DB_FILE_NAME};
259
260    // Test bad migration panics
261    #[gpui::test]
262    #[should_panic]
263    async fn test_bad_migration_panics() {
264        enum BadDB {}
265
266        impl Domain for BadDB {
267            fn name() -> &'static str {
268                "db_tests"
269            }
270
271            fn migrations() -> &'static [&'static str] {
272                &[
273                    sql!(CREATE TABLE test(value);),
274                    // failure because test already exists
275                    sql!(CREATE TABLE test(value);),
276                ]
277            }
278        }
279
280        let tempdir = TempDir::new("DbTests").unwrap();
281        let _bad_db = open_db::<BadDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
282    }
283
284    /// Test that DB exists but corrupted (causing recreate)
285    #[gpui::test]
286    async fn test_db_corruption() {
287        enum CorruptedDB {}
288
289        impl Domain for CorruptedDB {
290            fn name() -> &'static str {
291                "db_tests"
292            }
293
294            fn migrations() -> &'static [&'static str] {
295                &[sql!(CREATE TABLE test(value);)]
296            }
297        }
298
299        enum GoodDB {}
300
301        impl Domain for GoodDB {
302            fn name() -> &'static str {
303                "db_tests" //Notice same name
304            }
305
306            fn migrations() -> &'static [&'static str] {
307                &[sql!(CREATE TABLE test2(value);)] //But different migration
308            }
309        }
310
311        let tempdir = TempDir::new("DbTests").unwrap();
312        {
313            let corrupt_db =
314                open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
315            assert!(corrupt_db.persistent());
316        }
317
318        let good_db = open_db::<GoodDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
319        assert!(
320            good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
321                .unwrap()
322                .is_none()
323        );
324
325        let mut corrupted_backup_dir = fs::read_dir(tempdir.path())
326            .unwrap()
327            .find(|entry| {
328                !entry
329                    .as_ref()
330                    .unwrap()
331                    .file_name()
332                    .to_str()
333                    .unwrap()
334                    .starts_with("0")
335            })
336            .unwrap()
337            .unwrap()
338            .path();
339        corrupted_backup_dir.push(DB_FILE_NAME);
340
341        let backup = Connection::open_file(&corrupted_backup_dir.to_string_lossy());
342        assert!(backup.select_row::<usize>("SELECT * FROM test").unwrap()()
343            .unwrap()
344            .is_none());
345    }
346
347    /// Test that DB exists but corrupted (causing recreate)
348    #[gpui::test]
349    async fn test_simultaneous_db_corruption() {
350        enum CorruptedDB {}
351
352        impl Domain for CorruptedDB {
353            fn name() -> &'static str {
354                "db_tests"
355            }
356
357            fn migrations() -> &'static [&'static str] {
358                &[sql!(CREATE TABLE test(value);)]
359            }
360        }
361
362        enum GoodDB {}
363
364        impl Domain for GoodDB {
365            fn name() -> &'static str {
366                "db_tests" //Notice same name
367            }
368
369            fn migrations() -> &'static [&'static str] {
370                &[sql!(CREATE TABLE test2(value);)] //But different migration
371            }
372        }
373
374        let tempdir = TempDir::new("DbTests").unwrap();
375        {
376            // Setup the bad database
377            let corrupt_db =
378                open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
379            assert!(corrupt_db.persistent());
380        }
381
382        // Try to connect to it a bunch of times at once
383        let mut guards = vec![];
384        for _ in 0..10 {
385            let tmp_path = tempdir.path().to_path_buf();
386            let guard = thread::spawn(move || {
387                let good_db = smol::block_on(open_db::<GoodDB>(
388                    tmp_path.as_path(),
389                    &util::channel::ReleaseChannel::Dev,
390                ));
391                assert!(
392                    good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
393                        .unwrap()
394                        .is_none()
395                );
396            });
397
398            guards.push(guard);
399        }
400
401        for guard in guards.into_iter() {
402            assert!(guard.join().is_ok());
403        }
404    }
405}