db.rs

  1pub mod kvp;
  2pub mod query;
  3
  4// Re-export
  5pub use anyhow;
  6use anyhow::Context as _;
  7use gpui::{App, AppContext};
  8pub use indoc::indoc;
  9pub use paths::database_dir;
 10pub use smol;
 11pub use sqlez;
 12pub use sqlez_macros;
 13
 14pub use release_channel::RELEASE_CHANNEL;
 15use sqlez::domain::Migrator;
 16use sqlez::thread_safe_connection::ThreadSafeConnection;
 17use sqlez_macros::sql;
 18use std::future::Future;
 19use std::path::Path;
 20use std::sync::{LazyLock, atomic::Ordering};
 21use std::{env, sync::atomic::AtomicBool};
 22use util::{ResultExt, maybe};
 23
/// SQL passed to `with_connection_initialize_query` by every connection builder in this file
/// (`open_main_db`, `open_fallback_db`, `open_test_db`).
///
/// `PRAGMA foreign_keys=TRUE` turns on SQLite's enforcement of foreign-key constraints, which
/// SQLite leaves disabled by default and tracks per connection.
const CONNECTION_INITIALIZE_QUERY: &str = sql!(
    PRAGMA foreign_keys=TRUE;
);

/// SQL passed to `with_db_initialization_query` by every connection builder in this file.
///
/// - `journal_mode=WAL`: use SQLite's write-ahead log instead of a rollback journal.
/// - `busy_timeout=1`: wait at most 1 ms (SQLite's unit for this pragma) for a lock before
///   failing with `SQLITE_BUSY`.
/// - `case_sensitive_like=TRUE`: make the `LIKE` operator case-sensitive.
/// - `synchronous=NORMAL`: sync less often than SQLite's default `FULL`. In WAL mode this keeps
///   the database consistent, but a power loss may roll back the most recent commits.
///
/// NOTE(review): in SQLite, `busy_timeout`, `case_sensitive_like` and `synchronous` are
/// per-connection settings. Only `journal_mode=WAL` is persisted in the database file. If the
/// db initialization query runs on a single connection only, other connections will not get
/// these settings. Confirm against `ThreadSafeConnection`.
const DB_INITIALIZE_QUERY: &str = sql!(
    PRAGMA journal_mode=WAL;
    PRAGMA busy_timeout=1;
    PRAGMA case_sensitive_like=TRUE;
    PRAGMA synchronous=NORMAL;
);
 34
 35const FALLBACK_DB_NAME: &str = "FALLBACK_MEMORY_DB";
 36
 37const DB_FILE_NAME: &str = "db.sqlite";
 38
 39pub static ZED_STATELESS: LazyLock<bool> =
 40    LazyLock::new(|| env::var("ZED_STATELESS").is_ok_and(|v| !v.is_empty()));
 41
 42pub static ALL_FILE_DB_FAILED: LazyLock<AtomicBool> = LazyLock::new(|| AtomicBool::new(false));
 43
 44/// Open or create a database at the given directory path.
 45/// This will retry a couple times if there are failures. If opening fails once, the db directory
 46/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
 47/// In either case, static variables are set so that the user can be notified.
 48pub async fn open_db<M: Migrator + 'static>(db_dir: &Path, scope: &str) -> ThreadSafeConnection {
 49    if *ZED_STATELESS {
 50        return open_fallback_db::<M>().await;
 51    }
 52
 53    let main_db_dir = db_dir.join(format!("0-{}", scope));
 54
 55    let connection = maybe!(async {
 56        smol::fs::create_dir_all(&main_db_dir)
 57            .await
 58            .context("Could not create db directory")
 59            .log_err()?;
 60        let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
 61        open_main_db::<M>(&db_path).await
 62    })
 63    .await;
 64
 65    if let Some(connection) = connection {
 66        return connection;
 67    }
 68
 69    // Set another static ref so that we can escalate the notification
 70    ALL_FILE_DB_FAILED.store(true, Ordering::Release);
 71
 72    // If still failed, create an in memory db with a known name
 73    open_fallback_db::<M>().await
 74}
 75
 76async fn open_main_db<M: Migrator>(db_path: &Path) -> Option<ThreadSafeConnection> {
 77    log::trace!("Opening database {}", db_path.display());
 78    ThreadSafeConnection::builder::<M>(db_path.to_string_lossy().as_ref(), true)
 79        .with_db_initialization_query(DB_INITIALIZE_QUERY)
 80        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
 81        .build()
 82        .await
 83        .log_err()
 84}
 85
 86async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection {
 87    log::warn!("Opening fallback in-memory database");
 88    ThreadSafeConnection::builder::<M>(FALLBACK_DB_NAME, false)
 89        .with_db_initialization_query(DB_INITIALIZE_QUERY)
 90        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
 91        .build()
 92        .await
 93        .expect(
 94            "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
 95        )
 96}
 97
 98#[cfg(any(test, feature = "test-support"))]
 99pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection {
100    use sqlez::thread_safe_connection::locking_queue;
101
102    ThreadSafeConnection::builder::<M>(db_name, false)
103        .with_db_initialization_query(DB_INITIALIZE_QUERY)
104        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
105        // Serialize queued writes via a mutex and run them synchronously
106        .with_write_queue_constructor(locking_queue())
107        .build()
108        .await
109        .unwrap()
110}
111
/// Implements a basic DB wrapper for a given domain.
///
/// Accepted forms:
///
/// - `pub static ref ID: Type<()> = MIGRATIONS;`: the connection is opened with `Type` itself
///   as the migrator.
/// - `pub static ref ID: Type<Dep1, Dep2, ...> = MIGRATIONS;`: the connection is opened with
///   the tuple `(Dep1, Dep2, ..., Type)` as the migrator. Presumably this is so the dependency
///   domains' migrations are applied as well.
///
/// Either form may end with an identifier after the `;`. If that identifier is `global`, the
/// database is opened under the `"global"` scope. Otherwise, including when it is absent, the
/// release channel's `dev_name()` is used.
///
/// Generated items:
/// - a newtype `Type(ThreadSafeConnection)` that derefs to the connection;
/// - a `Domain` impl whose `name()` is the type name and whose `migrations()` is `MIGRATIONS`;
/// - a lazily-initialized static `ID`. In test / `test-support` builds it blocks on
///   `open_test_db` named after `ID`. Otherwise it blocks on `open_db` under `database_dir()`.
/// - in the `<()>` form only, a test-only `Type::open_test_db(name)` constructor.
///   NOTE(review): the multi-domain form does not generate this constructor; consider adding
///   it if nothing downstream defines its own.
#[macro_export]
macro_rules! define_connection {
    (pub static ref $id:ident: $t:ident<()> = $migrations:expr; $($global:ident)?) => {
        $crate::define_connection!(@define $id, $t, $t, $migrations; $($global)?);

        impl $t {
            #[cfg(any(test, feature = "test-support"))]
            pub async fn open_test_db(name: &'static str) -> Self {
                $t($crate::open_test_db::<$t>(name).await)
            }
        }
    };
    (pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr; $($global:ident)?) => {
        $crate::define_connection!(@define $id, $t, ($($d),+, $t), $migrations; $($global)?);
    };
    // Internal rule shared by both public forms. `$migrator` is the type passed to
    // `open_db` / `open_test_db` when the static is first accessed.
    (@define $id:ident, $t:ident, $migrator:ty, $migrations:expr; $($global:ident)?) => {
        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection);

        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl $crate::sqlez::domain::Domain for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn migrations() -> &'static [&'static str] {
                $migrations
            }
        }

        #[cfg(any(test, feature = "test-support"))]
        pub static $id: std::sync::LazyLock<$t> = std::sync::LazyLock::new(|| {
            let connection =
                $crate::smol::block_on($crate::open_test_db::<$migrator>(stringify!($id)));
            $t(connection)
        });

        #[cfg(not(any(test, feature = "test-support")))]
        pub static $id: std::sync::LazyLock<$t> = std::sync::LazyLock::new(|| {
            let db_dir = $crate::database_dir();
            let is_global = false $(|| stringify!($global) == "global")?;
            let scope = if is_global {
                "global"
            } else {
                $crate::RELEASE_CHANNEL.dev_name()
            };
            let connection = $crate::smol::block_on($crate::open_db::<$migrator>(db_dir, scope));
            $t(connection)
        });
    };
}
197
198pub fn write_and_log<F>(cx: &App, db_write: impl FnOnce() -> F + Send + 'static)
199where
200    F: Future<Output = anyhow::Result<()>> + Send,
201{
202    cx.background_spawn(async move { db_write().await.log_err() })
203        .detach()
204}
205
#[cfg(test)]
mod tests {
    use std::thread;

    use sqlez::domain::Domain;
    use sqlez_macros::sql;

    use crate::open_db;

    /// Domain written into the temp directory first; its migration creates table `test`.
    ///
    /// Its `name()` deliberately matches `GoodDB`, but its migration differs. Presumably this
    /// leaves migration history that `GoodDB`'s migrations do not match.
    enum CorruptedDB {}

    impl Domain for CorruptedDB {
        fn name() -> &'static str {
            "db_tests"
        }

        fn migrations() -> &'static [&'static str] {
            &[sql!(CREATE TABLE test(value);)]
        }
    }

    /// Domain opened after `CorruptedDB`: same name, but its migration creates table `test2`.
    enum GoodDB {}

    impl Domain for GoodDB {
        fn name() -> &'static str {
            "db_tests" // Same name as `CorruptedDB`...
        }

        fn migrations() -> &'static [&'static str] {
            &[sql!(CREATE TABLE test2(value);)] // ...but a different migration.
        }
    }

    /// Fresh temporary directory holding the databases for a single test.
    fn db_tempdir() -> tempfile::TempDir {
        tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap()
    }

    /// A domain whose second migration re-creates an existing table must make `open_db` panic.
    #[gpui::test]
    #[should_panic]
    async fn test_bad_migration_panics() {
        enum BadDB {}

        impl Domain for BadDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[
                    sql!(CREATE TABLE test(value);),
                    // failure because test already exists
                    sql!(CREATE TABLE test(value);),
                ]
            }
        }

        let tempdir = db_tempdir();
        let _bad_db = open_db::<BadDB>(
            tempdir.path(),
            release_channel::ReleaseChannel::Dev.dev_name(),
        )
        .await;
    }

    /// Opening `GoodDB` in a directory already populated by the conflicting `CorruptedDB`
    /// must still yield a connection with an empty `test2` table.
    #[gpui::test]
    async fn test_db_corruption(cx: &mut gpui::TestAppContext) {
        cx.executor().allow_parking();

        let tempdir = db_tempdir();

        let corrupt_db = open_db::<CorruptedDB>(
            tempdir.path(),
            release_channel::ReleaseChannel::Dev.dev_name(),
        )
        .await;
        assert!(corrupt_db.persistent());
        // Close the first connection before reopening the same directory.
        drop(corrupt_db);

        let good_db = open_db::<GoodDB>(
            tempdir.path(),
            release_channel::ReleaseChannel::Dev.dev_name(),
        )
        .await;
        assert!(
            good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                .unwrap()
                .is_none()
        );
    }

    /// Same scenario as `test_db_corruption`, but `GoodDB` is opened from ten threads at once.
    /// Every thread must get a connection with an empty `test2` table.
    #[gpui::test(iterations = 30)]
    async fn test_simultaneous_db_corruption(cx: &mut gpui::TestAppContext) {
        cx.executor().allow_parking();

        let tempdir = db_tempdir();

        // Set up the conflicting database, then close it again.
        let corrupt_db = open_db::<CorruptedDB>(
            tempdir.path(),
            release_channel::ReleaseChannel::Dev.dev_name(),
        )
        .await;
        assert!(corrupt_db.persistent());
        drop(corrupt_db);

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let db_dir = tempdir.path().to_path_buf();
                thread::spawn(move || {
                    let good_db = smol::block_on(open_db::<GoodDB>(
                        db_dir.as_path(),
                        release_channel::ReleaseChannel::Dev.dev_name(),
                    ));
                    assert!(
                        good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                            .unwrap()
                            .is_none()
                    );
                })
            })
            .collect();

        for handle in handles {
            assert!(handle.join().is_ok());
        }
    }
}