db2.rs

  1pub mod kvp;
  2pub mod query;
  3
  4// Re-export
  5pub use anyhow;
  6use anyhow::Context;
  7use gpui2::AppContext;
  8pub use indoc::indoc;
  9pub use lazy_static;
 10pub use smol;
 11pub use sqlez;
 12pub use sqlez_macros;
 13pub use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
 14pub use util::paths::DB_DIR;
 15
 16use sqlez::domain::Migrator;
 17use sqlez::thread_safe_connection::ThreadSafeConnection;
 18use sqlez_macros::sql;
 19use std::future::Future;
 20use std::path::{Path, PathBuf};
 21use std::sync::atomic::{AtomicBool, Ordering};
 22use util::channel::ReleaseChannel;
 23use util::{async_maybe, ResultExt};
 24
 25const CONNECTION_INITIALIZE_QUERY: &'static str = sql!(
 26    PRAGMA foreign_keys=TRUE;
 27);
 28
 29const DB_INITIALIZE_QUERY: &'static str = sql!(
 30    PRAGMA journal_mode=WAL;
 31    PRAGMA busy_timeout=1;
 32    PRAGMA case_sensitive_like=TRUE;
 33    PRAGMA synchronous=NORMAL;
 34);
 35
 36const FALLBACK_DB_NAME: &'static str = "FALLBACK_MEMORY_DB";
 37
 38const DB_FILE_NAME: &'static str = "db.sqlite";
 39
 40lazy_static::lazy_static! {
 41    pub static ref ZED_STATELESS: bool = std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty());
 42    pub static ref ALL_FILE_DB_FAILED: AtomicBool = AtomicBool::new(false);
 43}
 44
 45/// Open or create a database at the given directory path.
 46/// This will retry a couple times if there are failures. If opening fails once, the db directory
 47/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
 48/// In either case, static variables are set so that the user can be notified.
 49pub async fn open_db<M: Migrator + 'static>(
 50    db_dir: &Path,
 51    release_channel: &ReleaseChannel,
 52) -> ThreadSafeConnection<M> {
 53    if *ZED_STATELESS {
 54        return open_fallback_db().await;
 55    }
 56
 57    let release_channel_name = release_channel.dev_name();
 58    let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel_name)));
 59
 60    let connection = async_maybe!({
 61        smol::fs::create_dir_all(&main_db_dir)
 62            .await
 63            .context("Could not create db directory")
 64            .log_err()?;
 65        let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
 66        open_main_db(&db_path).await
 67    })
 68    .await;
 69
 70    if let Some(connection) = connection {
 71        return connection;
 72    }
 73
 74    // Set another static ref so that we can escalate the notification
 75    ALL_FILE_DB_FAILED.store(true, Ordering::Release);
 76
 77    // If still failed, create an in memory db with a known name
 78    open_fallback_db().await
 79}
 80
 81async fn open_main_db<M: Migrator>(db_path: &PathBuf) -> Option<ThreadSafeConnection<M>> {
 82    log::info!("Opening main db");
 83    ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
 84        .with_db_initialization_query(DB_INITIALIZE_QUERY)
 85        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
 86        .build()
 87        .await
 88        .log_err()
 89}
 90
 91async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> {
 92    log::info!("Opening fallback db");
 93    ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false)
 94        .with_db_initialization_query(DB_INITIALIZE_QUERY)
 95        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
 96        .build()
 97        .await
 98        .expect(
 99            "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
100        )
101}
102
103#[cfg(any(test, feature = "test-support"))]
104pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
105    use sqlez::thread_safe_connection::locking_queue;
106
107    ThreadSafeConnection::<M>::builder(db_name, false)
108        .with_db_initialization_query(DB_INITIALIZE_QUERY)
109        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
110        // Serialize queued writes via a mutex and run them synchronously
111        .with_write_queue_constructor(locking_queue())
112        .build()
113        .await
114        .unwrap()
115}
116
117/// Implements a basic DB wrapper for a given domain
118#[macro_export]
119macro_rules! define_connection {
120    (pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => {
121        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>);
122
123        impl ::std::ops::Deref for $t {
124            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>;
125
126            fn deref(&self) -> &Self::Target {
127                &self.0
128            }
129        }
130
131        impl $crate::sqlez::domain::Domain for $t {
132            fn name() -> &'static str {
133                stringify!($t)
134            }
135
136            fn migrations() -> &'static [&'static str] {
137                $migrations
138            }
139        }
140
141        #[cfg(any(test, feature = "test-support"))]
142        $crate::lazy_static::lazy_static! {
143            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
144        }
145
146        #[cfg(not(any(test, feature = "test-support")))]
147        $crate::lazy_static::lazy_static! {
148            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
149        }
150    };
151    (pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => {
152        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<( $($d),+, $t )>);
153
154        impl ::std::ops::Deref for $t {
155            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<($($d),+, $t)>;
156
157            fn deref(&self) -> &Self::Target {
158                &self.0
159            }
160        }
161
162        impl $crate::sqlez::domain::Domain for $t {
163            fn name() -> &'static str {
164                stringify!($t)
165            }
166
167            fn migrations() -> &'static [&'static str] {
168                $migrations
169            }
170        }
171
172        #[cfg(any(test, feature = "test-support"))]
173        $crate::lazy_static::lazy_static! {
174            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
175        }
176
177        #[cfg(not(any(test, feature = "test-support")))]
178        $crate::lazy_static::lazy_static! {
179            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
180        }
181    };
182}
183
184pub fn write_and_log<F>(cx: &mut AppContext, db_write: impl FnOnce() -> F + Send + 'static)
185where
186    F: Future<Output = anyhow::Result<()>> + Send,
187{
188    cx.executor()
189        .spawn(async move { db_write().await.log_err() })
190        .detach()
191}
192
193// #[cfg(test)]
194// mod tests {
195//     use std::thread;
196
197//     use sqlez::domain::Domain;
198//     use sqlez_macros::sql;
199//     use tempdir::TempDir;
200
201//     use crate::open_db;
202
203//     // Test bad migration panics
204//     #[gpui::test]
205//     #[should_panic]
206//     async fn test_bad_migration_panics() {
207//         enum BadDB {}
208
209//         impl Domain for BadDB {
210//             fn name() -> &'static str {
211//                 "db_tests"
212//             }
213
214//             fn migrations() -> &'static [&'static str] {
215//                 &[
216//                     sql!(CREATE TABLE test(value);),
217//                     // failure because test already exists
218//                     sql!(CREATE TABLE test(value);),
219//                 ]
220//             }
221//         }
222
223//         let tempdir = TempDir::new("DbTests").unwrap();
224//         let _bad_db = open_db::<BadDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
225//     }
226
227//     /// Test that DB exists but corrupted (causing recreate)
228//     #[gpui::test]
229//     async fn test_db_corruption() {
230//         enum CorruptedDB {}
231
232//         impl Domain for CorruptedDB {
233//             fn name() -> &'static str {
234//                 "db_tests"
235//             }
236
237//             fn migrations() -> &'static [&'static str] {
238//                 &[sql!(CREATE TABLE test(value);)]
239//             }
240//         }
241
242//         enum GoodDB {}
243
244//         impl Domain for GoodDB {
245//             fn name() -> &'static str {
246//                 "db_tests" //Notice same name
247//             }
248
249//             fn migrations() -> &'static [&'static str] {
250//                 &[sql!(CREATE TABLE test2(value);)] //But different migration
251//             }
252//         }
253
254//         let tempdir = TempDir::new("DbTests").unwrap();
255//         {
256//             let corrupt_db =
257//                 open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
258//             assert!(corrupt_db.persistent());
259//         }
260
261//         let good_db = open_db::<GoodDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
262//         assert!(
263//             good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
264//                 .unwrap()
265//                 .is_none()
266//         );
267//     }
268
 //     /// Test that several threads opening a corrupted DB simultaneously all succeed
270//     #[gpui::test(iterations = 30)]
271//     async fn test_simultaneous_db_corruption() {
272//         enum CorruptedDB {}
273
274//         impl Domain for CorruptedDB {
275//             fn name() -> &'static str {
276//                 "db_tests"
277//             }
278
279//             fn migrations() -> &'static [&'static str] {
280//                 &[sql!(CREATE TABLE test(value);)]
281//             }
282//         }
283
284//         enum GoodDB {}
285
286//         impl Domain for GoodDB {
287//             fn name() -> &'static str {
288//                 "db_tests" //Notice same name
289//             }
290
291//             fn migrations() -> &'static [&'static str] {
292//                 &[sql!(CREATE TABLE test2(value);)] //But different migration
293//             }
294//         }
295
296//         let tempdir = TempDir::new("DbTests").unwrap();
297//         {
298//             // Setup the bad database
299//             let corrupt_db =
300//                 open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
301//             assert!(corrupt_db.persistent());
302//         }
303
304//         // Try to connect to it a bunch of times at once
305//         let mut guards = vec![];
306//         for _ in 0..10 {
307//             let tmp_path = tempdir.path().to_path_buf();
308//             let guard = thread::spawn(move || {
309//                 let good_db = smol::block_on(open_db::<GoodDB>(
310//                     tmp_path.as_path(),
311//                     &util::channel::ReleaseChannel::Dev,
312//                 ));
313//                 assert!(
314//                     good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
315//                         .unwrap()
316//                         .is_none()
317//                 );
318//             });
319
320//             guards.push(guard);
321//         }
322
323//         for guard in guards.into_iter() {
324//             assert!(guard.join().is_ok());
325//         }
326//     }
327// }