db.rs

  1pub mod kvp;
  2pub mod query;
  3
  4// Re-export
  5pub use anyhow;
  6use anyhow::Context as _;
  7use gpui::{App, AppContext};
  8pub use indoc::indoc;
  9pub use paths::database_dir;
 10pub use smol;
 11pub use sqlez;
 12pub use sqlez_macros;
 13
 14pub use release_channel::RELEASE_CHANNEL;
 15use sqlez::domain::Migrator;
 16use sqlez::thread_safe_connection::ThreadSafeConnection;
 17use sqlez_macros::sql;
 18use std::future::Future;
 19use std::path::Path;
 20use std::sync::atomic::AtomicBool;
 21use std::sync::{LazyLock, atomic::Ordering};
 22use util::{ResultExt, maybe};
 23use zed_env_vars::ZED_STATELESS;
 24
 25const CONNECTION_INITIALIZE_QUERY: &str = sql!(
 26    PRAGMA foreign_keys=TRUE;
 27);
 28
 29const DB_INITIALIZE_QUERY: &str = sql!(
 30    PRAGMA journal_mode=WAL;
 31    PRAGMA busy_timeout=1;
 32    PRAGMA case_sensitive_like=TRUE;
 33    PRAGMA synchronous=NORMAL;
 34);
 35
 36const FALLBACK_DB_NAME: &str = "FALLBACK_MEMORY_DB";
 37
 38const DB_FILE_NAME: &str = "db.sqlite";
 39
 40pub static ALL_FILE_DB_FAILED: LazyLock<AtomicBool> = LazyLock::new(|| AtomicBool::new(false));
 41
 42/// Open or create a database at the given directory path.
 43/// This will retry a couple times if there are failures. If opening fails once, the db directory
 44/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
 45/// In either case, static variables are set so that the user can be notified.
 46pub async fn open_db<M: Migrator + 'static>(db_dir: &Path, scope: &str) -> ThreadSafeConnection {
 47    if *ZED_STATELESS {
 48        return open_fallback_db::<M>().await;
 49    }
 50
 51    let main_db_dir = db_dir.join(format!("0-{}", scope));
 52
 53    let connection = maybe!(async {
 54        smol::fs::create_dir_all(&main_db_dir)
 55            .await
 56            .context("Could not create db directory")
 57            .log_err()?;
 58        let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
 59        open_main_db::<M>(&db_path).await
 60    })
 61    .await;
 62
 63    if let Some(connection) = connection {
 64        return connection;
 65    }
 66
 67    // Set another static ref so that we can escalate the notification
 68    ALL_FILE_DB_FAILED.store(true, Ordering::Release);
 69
 70    // If still failed, create an in memory db with a known name
 71    open_fallback_db::<M>().await
 72}
 73
 74async fn open_main_db<M: Migrator>(db_path: &Path) -> Option<ThreadSafeConnection> {
 75    log::trace!("Opening database {}", db_path.display());
 76    ThreadSafeConnection::builder::<M>(db_path.to_string_lossy().as_ref(), true)
 77        .with_db_initialization_query(DB_INITIALIZE_QUERY)
 78        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
 79        .build()
 80        .await
 81        .log_err()
 82}
 83
 84async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection {
 85    log::warn!("Opening fallback in-memory database");
 86    ThreadSafeConnection::builder::<M>(FALLBACK_DB_NAME, false)
 87        .with_db_initialization_query(DB_INITIALIZE_QUERY)
 88        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
 89        .build()
 90        .await
 91        .expect(
 92            "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
 93        )
 94}
 95
 96#[cfg(any(test, feature = "test-support"))]
 97pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection {
 98    use sqlez::thread_safe_connection::locking_queue;
 99
100    ThreadSafeConnection::builder::<M>(db_name, false)
101        .with_db_initialization_query(DB_INITIALIZE_QUERY)
102        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
103        // Serialize queued writes via a mutex and run them synchronously
104        .with_write_queue_constructor(locking_queue())
105        .build()
106        .await
107        .unwrap()
108}
109
/// Implements a basic DB wrapper for a given domain
///
/// Arguments:
/// - static variable name for connection
/// - type of connection wrapper; it must be a tuple struct whose field `0` is the
///   `ThreadSafeConnection`, because the macro constructs it as `$t(connection)` and
///   dereferences through `self.0`
/// - dependencies, whose migrations should be run prior to this domain's migrations
/// - optionally, the identifier `global`: when present, the database is opened in the
///   `"global"` scope instead of the scope named after the current release channel
///
/// What gets generated:
/// - `Deref<Target = ThreadSafeConnection>` for the wrapper type
/// - in test / `test-support` builds, an `open_test_db(name)` constructor on the wrapper
/// - a lazily initialized `pub static` holding the wrapper. In test / `test-support` builds
///   it is opened with `open_test_db`, named after the static's identifier; otherwise it is
///   opened with `open_db` inside `database_dir()`. Initialization blocks the current thread
///   until the database is open.
///
/// ```ignore
/// static_connection!(MY_DB, MyDb, [OtherDomain]);
/// static_connection!(MY_GLOBAL_DB, MyGlobalDb, [], global);
/// ```
#[macro_export]
macro_rules! static_connection {
    ($id:ident, $t:ident, [ $($d:ty),* ] $(, $global:ident)?) => {
        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl $t {
            #[cfg(any(test, feature = "test-support"))]
            pub async fn open_test_db(name: &'static str) -> Self {
                $t($crate::open_test_db::<$t>(name).await)
            }
        }

        // Absolute `::std` paths keep the expansion independent of whatever `std` resolves
        // to at the call site.
        #[cfg(any(test, feature = "test-support"))]
        pub static $id: ::std::sync::LazyLock<$t> = ::std::sync::LazyLock::new(|| {
            // With no dependencies the migrator type expands to `($t)`.
            #[allow(unused_parens)]
            $t($crate::smol::block_on($crate::open_test_db::<($($d,)* $t)>(stringify!($id))))
        });

        #[cfg(not(any(test, feature = "test-support")))]
        pub static $id: ::std::sync::LazyLock<$t> = ::std::sync::LazyLock::new(|| {
            let db_dir = $crate::database_dir();
            // `if false` is the base case when the optional fourth argument is absent.
            let scope = if false $(|| stringify!($global) == "global")? {
                "global"
            } else {
                $crate::RELEASE_CHANNEL.dev_name()
            };
            // With no dependencies the migrator type expands to `($t)`.
            #[allow(unused_parens)]
            $t($crate::smol::block_on($crate::open_db::<($($d,)* $t)>(db_dir, scope)))
        });
    }
}
153
154pub fn write_and_log<F>(cx: &App, db_write: impl FnOnce() -> F + Send + 'static)
155where
156    F: Future<Output = anyhow::Result<()>> + Send,
157{
158    cx.background_spawn(async move { db_write().await.log_err() })
159        .detach()
160}
161
#[cfg(test)]
mod tests {
    use std::path::Path;
    use std::thread;

    use sqlez::domain::{Domain, Migrator};
    use sqlez::thread_safe_connection::ThreadSafeConnection;
    use sqlez_macros::sql;

    use crate::open_db;

    /// Domain used to create the initial on-disk database; its only migration creates `test`.
    enum CorruptedDB {}

    impl Domain for CorruptedDB {
        const NAME: &str = "db_tests";
        const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test(value);)];
    }

    /// Deliberately shares its domain name with `CorruptedDB` while declaring a different
    /// migration (it creates `test2` instead of `test`).
    enum GoodDB {}

    impl Domain for GoodDB {
        const NAME: &str = "db_tests"; // Notice same name
        const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test2(value);)]; // But different migration
    }

    /// Creates a fresh temporary directory to hold a test database.
    fn temp_db_dir() -> tempfile::TempDir {
        tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap()
    }

    /// Opens the database for migrator `M` under `dir`, using the dev release channel's scope.
    async fn open_dev_db<M: Migrator + 'static>(dir: &Path) -> ThreadSafeConnection {
        open_db::<M>(dir, release_channel::ReleaseChannel::Dev.dev_name()).await
    }

    /// Asserts that table `test2` can be queried through `db` and yields no row.
    fn assert_test2_is_empty(db: &ThreadSafeConnection) {
        let first_row = db.select_row::<usize>("SELECT * FROM test2").unwrap()().unwrap();
        assert!(first_row.is_none());
    }

    /// Opening a database whose migrations cannot be applied must panic.
    #[gpui::test]
    #[should_panic]
    async fn test_bad_migration_panics() {
        enum BadDB {}

        impl Domain for BadDB {
            const NAME: &str = "db_tests";
            const MIGRATIONS: &[&str] = &[
                sql!(CREATE TABLE test(value);),
                // failure because test already exists
                sql!(CREATE TABLE test(value);),
            ];
        }

        let dir = temp_db_dir();
        let _bad_db = open_dev_db::<BadDB>(dir.path()).await;
    }

    /// Opens a persistent database with one set of migrations, then reopens the same
    /// location with a conflicting set under the same domain name, and checks that the
    /// second connection is still usable and exposes the new schema.
    #[gpui::test]
    async fn test_db_corruption(cx: &mut gpui::TestAppContext) {
        cx.executor().allow_parking();

        let dir = temp_db_dir();
        {
            // Set up the on-disk database with the conflicting migration.
            let corrupt_db = open_dev_db::<CorruptedDB>(dir.path()).await;
            assert!(corrupt_db.persistent());
        }

        let good_db = open_dev_db::<GoodDB>(dir.path()).await;
        assert_test2_is_empty(&good_db);
    }

    /// Same scenario as `test_db_corruption`, except that the conflicting database is
    /// reopened from ten threads at once; every thread must end up with a usable connection.
    #[gpui::test(iterations = 30)]
    async fn test_simultaneous_db_corruption(cx: &mut gpui::TestAppContext) {
        cx.executor().allow_parking();

        let dir = temp_db_dir();
        {
            // Setup the bad database
            let corrupt_db = open_dev_db::<CorruptedDB>(dir.path()).await;
            assert!(corrupt_db.persistent());
        }

        // Try to connect to it a bunch of times at once. All threads are spawned (and
        // collected) before any of them is joined so that they really run concurrently.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let tmp_path = dir.path().to_path_buf();
                thread::spawn(move || {
                    let good_db = smol::block_on(open_dev_db::<GoodDB>(&tmp_path));
                    assert_test2_is_empty(&good_db);
                })
            })
            .collect();

        for handle in handles {
            assert!(handle.join().is_ok());
        }
    }
}