db.rs

  1pub mod kvp;
  2pub mod query;
  3
  4// Re-export
  5pub use anyhow;
  6use anyhow::Context;
  7use gpui::AppContext;
  8pub use indoc::indoc;
  9pub use paths::database_dir;
 10pub use smol;
 11pub use sqlez;
 12pub use sqlez_macros;
 13
 14use release_channel::ReleaseChannel;
 15pub use release_channel::RELEASE_CHANNEL;
 16use sqlez::domain::Migrator;
 17use sqlez::thread_safe_connection::ThreadSafeConnection;
 18use sqlez_macros::sql;
 19use std::env;
 20use std::future::Future;
 21use std::path::Path;
 22use std::sync::atomic::{AtomicBool, Ordering};
 23use std::sync::LazyLock;
 24use util::{maybe, ResultExt};
 25
/// SQL handed to `with_connection_initialize_query` by every builder in this file.
///
/// `PRAGMA foreign_keys` is a per-connection setting in SQLite, so it belongs in the
/// per-connection query. Presumably sqlez runs this on each connection it opens — TODO confirm.
const CONNECTION_INITIALIZE_QUERY: &str = sql!(
    PRAGMA foreign_keys=TRUE;
);

/// SQL handed to `with_db_initialization_query` by every builder in this file.
///
/// - `journal_mode=WAL`: write-ahead logging (SQLite persists this mode in the database file).
/// - `busy_timeout=1`: wait at most 1 ms on a locked database before reporting SQLITE_BUSY.
/// - `case_sensitive_like=TRUE`: makes `LIKE` comparisons case sensitive.
/// - `synchronous=NORMAL`: fewer fsyncs than FULL; in WAL mode SQLite documents this as safe
///   from corruption but able to lose the most recent commits on power loss.
///
/// NOTE(review): `busy_timeout`, `case_sensitive_like` and `synchronous` are per-connection
/// settings in SQLite. If sqlez runs this query on only one connection, other connections will
/// not get them. Confirm, and consider moving them into `CONNECTION_INITIALIZE_QUERY`.
/// NOTE(review): a 1 ms busy timeout is very short; confirm it is intentional.
const DB_INITIALIZE_QUERY: &str = sql!(
    PRAGMA journal_mode=WAL;
    PRAGMA busy_timeout=1;
    PRAGMA case_sensitive_like=TRUE;
    PRAGMA synchronous=NORMAL;
);
 36
 37const FALLBACK_DB_NAME: &str = "FALLBACK_MEMORY_DB";
 38
 39const DB_FILE_NAME: &str = "db.sqlite";
 40
 41pub static ZED_STATELESS: LazyLock<bool> =
 42    LazyLock::new(|| env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty()));
 43
 44pub static ALL_FILE_DB_FAILED: LazyLock<AtomicBool> = LazyLock::new(|| AtomicBool::new(false));
 45
 46/// Open or create a database at the given directory path.
 47/// This will retry a couple times if there are failures. If opening fails once, the db directory
 48/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
 49/// In either case, static variables are set so that the user can be notified.
 50pub async fn open_db<M: Migrator + 'static>(
 51    db_dir: &Path,
 52    release_channel: &ReleaseChannel,
 53) -> ThreadSafeConnection<M> {
 54    if *ZED_STATELESS {
 55        return open_fallback_db().await;
 56    }
 57
 58    let release_channel_name = release_channel.dev_name();
 59    let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel_name)));
 60
 61    let connection = maybe!(async {
 62        smol::fs::create_dir_all(&main_db_dir)
 63            .await
 64            .context("Could not create db directory")
 65            .log_err()?;
 66        let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
 67        open_main_db(&db_path).await
 68    })
 69    .await;
 70
 71    if let Some(connection) = connection {
 72        return connection;
 73    }
 74
 75    // Set another static ref so that we can escalate the notification
 76    ALL_FILE_DB_FAILED.store(true, Ordering::Release);
 77
 78    // If still failed, create an in memory db with a known name
 79    open_fallback_db().await
 80}
 81
 82async fn open_main_db<M: Migrator>(db_path: &Path) -> Option<ThreadSafeConnection<M>> {
 83    log::info!("Opening main db");
 84    ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
 85        .with_db_initialization_query(DB_INITIALIZE_QUERY)
 86        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
 87        .build()
 88        .await
 89        .log_err()
 90}
 91
 92async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> {
 93    log::info!("Opening fallback db");
 94    ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false)
 95        .with_db_initialization_query(DB_INITIALIZE_QUERY)
 96        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
 97        .build()
 98        .await
 99        .expect(
100            "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
101        )
102}
103
/// Builds a non-persistent database called `db_name` for use in tests.
///
/// # Panics
///
/// Panics if the database cannot be built.
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
    use sqlez::thread_safe_connection::locking_queue;

    let builder = ThreadSafeConnection::<M>::builder(db_name, false)
        .with_db_initialization_query(DB_INITIALIZE_QUERY)
        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
        // Queued writes are serialized through a mutex and executed synchronously.
        .with_write_queue_constructor(locking_queue());
    builder.build().await.unwrap()
}
117
/// Declares a lazily-opened database connection for a domain.
///
/// ```ignore
/// define_connection!(pub static ref MY_DB: MyDb<()> = &[sql!(CREATE TABLE ...)];);
/// define_connection!(pub static ref MY_DB: MyDb<OtherDomain> = &[sql!(CREATE TABLE ...)];);
/// ```
///
/// Each invocation expands to:
/// - a newtype `$t` wrapping a `ThreadSafeConnection`, typed over `$t` alone (`<()>` form) or
///   over the tuple `(deps..., $t)` (dependency form), with `Deref` to that connection;
/// - a `Domain` impl for `$t`, named by the stringified type name, whose migrations are the
///   given expression;
/// - a `pub static $id: LazyLock<$t>` that blocks on opening the connection at first access.
///   When the invoking crate is built with `cfg(test)` or the `test-support` feature, that is
///   `open_test_db(stringify!($id))`; otherwise it is `open_db(database_dir(), &RELEASE_CHANNEL)`.
///
/// The expansion names `LazyLock` by its full path rather than through a `use` item.
/// `macro_rules!` items are not hygienic, so an expanded `use std::sync::LazyLock;` would leak
/// into the invoking module. It would then collide (E0252) with the caller's own import, or with
/// a second invocation in the same module.
#[macro_export]
macro_rules! define_connection {
    (pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => {
        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>);

        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl $crate::sqlez::domain::Domain for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn migrations() -> &'static [&'static str] {
                $migrations
            }
        }

        #[cfg(any(test, feature = "test-support"))]
        pub static $id: ::std::sync::LazyLock<$t> = ::std::sync::LazyLock::new(|| {
            $t($crate::smol::block_on($crate::open_test_db(stringify!($id))))
        });

        #[cfg(not(any(test, feature = "test-support")))]
        pub static $id: ::std::sync::LazyLock<$t> = ::std::sync::LazyLock::new(|| {
            $t($crate::smol::block_on($crate::open_db(
                $crate::database_dir(),
                &$crate::RELEASE_CHANNEL,
            )))
        });
    };
    (pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => {
        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<( $($d),+, $t )>);

        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<($($d),+, $t)>;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl $crate::sqlez::domain::Domain for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn migrations() -> &'static [&'static str] {
                $migrations
            }
        }

        #[cfg(any(test, feature = "test-support"))]
        pub static $id: ::std::sync::LazyLock<$t> = ::std::sync::LazyLock::new(|| {
            $t($crate::smol::block_on($crate::open_test_db(stringify!($id))))
        });

        #[cfg(not(any(test, feature = "test-support")))]
        pub static $id: ::std::sync::LazyLock<$t> = ::std::sync::LazyLock::new(|| {
            $t($crate::smol::block_on($crate::open_db(
                $crate::database_dir(),
                &$crate::RELEASE_CHANNEL,
            )))
        });
    };
}
185
186pub fn write_and_log<F>(cx: &mut AppContext, db_write: impl FnOnce() -> F + Send + 'static)
187where
188    F: Future<Output = anyhow::Result<()>> + Send,
189{
190    cx.background_executor()
191        .spawn(async move { db_write().await.log_err() })
192        .detach()
193}
194
#[cfg(test)]
mod tests {
    use std::path::Path;
    use std::thread;

    use release_channel::ReleaseChannel;
    use sqlez::domain::Domain;
    use sqlez_macros::sql;

    use crate::open_db;

    /// Domain whose migrations can never succeed: the second step re-creates the table that the
    /// first step created.
    enum BadDB {}

    impl Domain for BadDB {
        fn name() -> &'static str {
            "db_tests"
        }

        fn migrations() -> &'static [&'static str] {
            &[
                sql!(CREATE TABLE test(value);),
                // failure because test already exists
                sql!(CREATE TABLE test(value);),
            ]
        }
    }

    /// Domain used to seed an on-disk database with a migration history for `"db_tests"`.
    enum CorruptedDB {}

    impl Domain for CorruptedDB {
        fn name() -> &'static str {
            "db_tests"
        }

        fn migrations() -> &'static [&'static str] {
            &[sql!(CREATE TABLE test(value);)]
        }
    }

    /// Same domain name as [`CorruptedDB`] but a different first migration, so its migrations
    /// disagree with the history that [`CorruptedDB`] leaves on disk.
    enum GoodDB {}

    impl Domain for GoodDB {
        fn name() -> &'static str {
            "db_tests" // Same name as CorruptedDB...
        }

        fn migrations() -> &'static [&'static str] {
            &[sql!(CREATE TABLE test2(value);)] // ...but a different migration.
        }
    }

    /// Creates a fresh temporary directory to hold a test database.
    fn temp_db_dir() -> tempfile::TempDir {
        tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap()
    }

    /// Opens `dir` as a [`CorruptedDB`], checks that the connection is persistent, then drops it
    /// so the on-disk file keeps `CorruptedDB`'s migration history.
    async fn seed_corrupted_db(dir: &Path) {
        let corrupt_db = open_db::<CorruptedDB>(dir, &ReleaseChannel::Dev).await;
        assert!(corrupt_db.persistent());
    }

    /// A domain whose migrations can never succeed makes `open_db` panic, because even the
    /// in-memory fallback database cannot be migrated.
    #[gpui::test]
    #[should_panic]
    async fn test_bad_migration_panics() {
        let tempdir = temp_db_dir();
        let _bad_db = open_db::<BadDB>(tempdir.path(), &ReleaseChannel::Dev).await;
    }

    /// Opening an on-disk database whose recorded migrations conflict with the domain's current
    /// migrations must still yield a usable connection for the new domain.
    /// NOTE(review): `open_db` never recreates the file, so this presumably exercises the
    /// in-memory fallback — confirm.
    #[gpui::test]
    async fn test_db_corruption(cx: &mut gpui::TestAppContext) {
        cx.executor().allow_parking();

        let tempdir = temp_db_dir();
        seed_corrupted_db(tempdir.path()).await;

        let good_db = open_db::<GoodDB>(tempdir.path(), &ReleaseChannel::Dev).await;
        assert!(
            good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                .unwrap()
                .is_none()
        );
    }

    /// Same scenario as `test_db_corruption`, but ten threads open the conflicting database
    /// concurrently. Every one of them must end up with a usable connection.
    #[gpui::test(iterations = 30)]
    async fn test_simultaneous_db_corruption(cx: &mut gpui::TestAppContext) {
        cx.executor().allow_parking();

        let tempdir = temp_db_dir();
        seed_corrupted_db(tempdir.path()).await;

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let dir = tempdir.path().to_path_buf();
                thread::spawn(move || {
                    let good_db = smol::block_on(open_db::<GoodDB>(&dir, &ReleaseChannel::Dev));
                    assert!(
                        good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                            .unwrap()
                            .is_none()
                    );
                })
            })
            .collect();

        for handle in handles {
            assert!(handle.join().is_ok());
        }
    }
}