db.rs

  1pub mod kvp;
  2pub mod query;
  3
  4// Re-export
  5pub use anyhow;
  6use anyhow::Context as _;
  7pub use gpui;
  8use gpui::{App, AppContext, Global};
  9pub use indoc::indoc;
 10pub use inventory;
 11pub use paths::database_dir;
 12pub use smol;
 13pub use sqlez;
 14pub use sqlez_macros;
 15pub use uuid;
 16
 17pub use release_channel::RELEASE_CHANNEL;
 18use sqlez::domain::Migrator;
 19use sqlez::thread_safe_connection::ThreadSafeConnection;
 20use sqlez_macros::sql;
 21use std::future::Future;
 22use std::path::Path;
 23use std::sync::atomic::AtomicBool;
 24use std::sync::{LazyLock, atomic::Ordering};
 25use util::{ResultExt, maybe};
 26use zed_env_vars::ZED_STATELESS;
 27
/// A domain's migrations, registered via `static_connection!` and collected at
/// link time through `inventory`.
///
/// [`AppMigrator`] applies every registration, ordering domains so that each
/// one runs after the domains named in its `dependencies`.
pub struct DomainMigration {
    /// Domain name; `static_connection!` fills this from `Domain::NAME`.
    pub name: &'static str,
    /// Migration SQL passed to `Connection::migrate` for this domain.
    pub migrations: &'static [&'static str],
    /// Names of domains whose migrations must be applied before this one's.
    /// Names without a matching registration are ignored when ordering.
    pub dependencies: &'static [&'static str],
    /// Callback forwarded to `Connection::migrate`. NOTE(review): presumably
    /// consulted when an already-applied migration's SQL has changed (arguments
    /// look like step index, old SQL, new SQL) — confirm against sqlez.
    pub should_allow_migration_change: fn(usize, &str, &str) -> bool,
}

// Makes `DomainMigration` collectable: `inventory::submit!` (emitted by
// `static_connection!`) registers instances and `AppMigrator` iterates them.
inventory::collect!(DomainMigration);
 37
 38/// The shared database connection backing all domain-specific DB wrappers.
 39/// Set as a GPUI global per-App. Falls back to a shared LazyLock if not set.
 40pub struct AppDatabase(pub ThreadSafeConnection);
 41
 42impl Global for AppDatabase {}
 43
 44/// Migrator that runs all inventory-registered domain migrations.
 45pub struct AppMigrator;
 46
 47impl Migrator for AppMigrator {
 48    fn migrate(connection: &sqlez::connection::Connection) -> anyhow::Result<()> {
 49        let registrations: Vec<&DomainMigration> = inventory::iter::<DomainMigration>().collect();
 50        let sorted = topological_sort(&registrations);
 51        for reg in &sorted {
 52            let mut should_allow = reg.should_allow_migration_change;
 53            connection.migrate(reg.name, reg.migrations, &mut should_allow)?;
 54        }
 55        Ok(())
 56    }
 57}
 58
 59impl AppDatabase {
 60    /// Opens the production database and runs all inventory-registered
 61    /// migrations in dependency order.
 62    pub fn new() -> Self {
 63        let db_dir = database_dir();
 64        let scope = RELEASE_CHANNEL.dev_name();
 65        let connection = smol::block_on(open_db::<AppMigrator>(db_dir, scope));
 66        Self(connection)
 67    }
 68
 69    /// Creates a new in-memory database with a unique name and runs all
 70    /// inventory-registered migrations in dependency order.
 71    #[cfg(any(test, feature = "test-support"))]
 72    pub fn test_new() -> Self {
 73        let name = format!("test-db-{}", uuid::Uuid::new_v4());
 74        let connection = smol::block_on(open_test_db::<AppMigrator>(&name));
 75        Self(connection)
 76    }
 77
 78    /// Returns the per-App connection if set, otherwise falls back to
 79    /// the shared LazyLock.
 80    pub fn global(cx: &App) -> &ThreadSafeConnection {
 81        #[allow(unreachable_code)]
 82        if let Some(db) = cx.try_global::<Self>() {
 83            return &db.0;
 84        } else {
 85            #[cfg(any(feature = "test-support", test))]
 86            return &TEST_APP_DATABASE.0;
 87
 88            panic!("database not initialized")
 89        }
 90    }
 91}
 92
 93fn topological_sort<'a>(registrations: &[&'a DomainMigration]) -> Vec<&'a DomainMigration> {
 94    let mut sorted: Vec<&DomainMigration> = Vec::new();
 95    let mut visited: std::collections::HashSet<&str> = std::collections::HashSet::new();
 96
 97    fn visit<'a>(
 98        name: &str,
 99        registrations: &[&'a DomainMigration],
100        sorted: &mut Vec<&'a DomainMigration>,
101        visited: &mut std::collections::HashSet<&'a str>,
102    ) {
103        if visited.contains(name) {
104            return;
105        }
106        if let Some(reg) = registrations.iter().find(|r| r.name == name) {
107            for dep in reg.dependencies {
108                visit(dep, registrations, sorted, visited);
109            }
110            visited.insert(reg.name);
111            sorted.push(reg);
112        }
113    }
114
115    for reg in registrations {
116        visit(reg.name, registrations, &mut sorted, &mut visited);
117    }
118    sorted
119}
120
/// Shared fallback `AppDatabase` returned by `AppDatabase::global` when no
/// per-App global is set. Exists only in `test` / `test-support` builds and is
/// created lazily on first use via `AppDatabase::test_new`, so every caller
/// that hits the fallback shares one uniquely named in-memory database.
#[cfg(any(test, feature = "test-support"))]
static TEST_APP_DATABASE: LazyLock<AppDatabase> = LazyLock::new(AppDatabase::test_new);
124
/// Passed to `with_connection_initialize_query` by every builder in this file
/// (presumably run for each new connection). Turns on SQLite foreign-key
/// enforcement, which SQLite leaves off by default and tracks per connection.
const CONNECTION_INITIALIZE_QUERY: &str = sql!(
    PRAGMA foreign_keys=TRUE;
);

/// Passed to `with_db_initialization_query` by every builder in this file
/// (presumably run once when the database is opened):
/// - `journal_mode=WAL`: write-ahead logging, so readers don't block the writer.
/// - `busy_timeout=500`: wait up to 500 ms on a locked database before erroring.
/// - `case_sensitive_like=TRUE`: make `LIKE` comparisons case-sensitive.
/// - `synchronous=NORMAL`: fewer fsyncs. In WAL mode the most recent commits
///   may roll back after a power loss, but the database is not corrupted.
const DB_INITIALIZE_QUERY: &str = sql!(
    PRAGMA journal_mode=WAL;
    PRAGMA busy_timeout=500;
    PRAGMA case_sensitive_like=TRUE;
    PRAGMA synchronous=NORMAL;
);

/// Name of the shared, non-persistent database opened by `open_fallback_db`,
/// used when the on-disk database is unavailable or `ZED_STATELESS` is set.
const FALLBACK_DB_NAME: &str = "FALLBACK_MEMORY_DB";

/// File name of the on-disk database inside `<db_dir>/0-<scope>/`.
const DB_FILE_NAME: &str = "db.sqlite";

/// Set to `true` by `open_db` when the on-disk database could not be opened and
/// the in-memory fallback was used instead, so the user can be notified.
/// It is not set when `ZED_STATELESS` deliberately skips the on-disk database.
pub static ALL_FILE_DB_FAILED: LazyLock<AtomicBool> = LazyLock::new(|| AtomicBool::new(false));
141
142/// Open or create a database at the given directory path.
143/// This will retry a couple times if there are failures. If opening fails once, the db directory
144/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
145/// In either case, static variables are set so that the user can be notified.
146pub async fn open_db<M: Migrator + 'static>(db_dir: &Path, scope: &str) -> ThreadSafeConnection {
147    if *ZED_STATELESS {
148        return open_fallback_db::<M>().await;
149    }
150
151    let main_db_dir = db_dir.join(format!("0-{}", scope));
152
153    let connection = maybe!(async {
154        smol::fs::create_dir_all(&main_db_dir)
155            .await
156            .context("Could not create db directory")
157            .log_err()?;
158        let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
159        open_main_db::<M>(&db_path).await
160    })
161    .await;
162
163    if let Some(connection) = connection {
164        return connection;
165    }
166
167    // Set another static ref so that we can escalate the notification
168    ALL_FILE_DB_FAILED.store(true, Ordering::Release);
169
170    // If still failed, create an in memory db with a known name
171    open_fallback_db::<M>().await
172}
173
174async fn open_main_db<M: Migrator>(db_path: &Path) -> Option<ThreadSafeConnection> {
175    log::trace!("Opening database {}", db_path.display());
176    ThreadSafeConnection::builder::<M>(db_path.to_string_lossy().as_ref(), true)
177        .with_db_initialization_query(DB_INITIALIZE_QUERY)
178        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
179        .build()
180        .await
181        .log_err()
182}
183
184async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection {
185    log::warn!("Opening fallback in-memory database");
186    ThreadSafeConnection::builder::<M>(FALLBACK_DB_NAME, false)
187        .with_db_initialization_query(DB_INITIALIZE_QUERY)
188        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
189        .build()
190        .await
191        .expect(
192            "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
193        )
194}
195
/// Opens a non-persistent test database named `db_name` and runs `M`'s
/// migrations on it.
///
/// # Panics
///
/// Panics if the database cannot be built.
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection {
    let built = ThreadSafeConnection::builder::<M>(db_name, false)
        .with_db_initialization_query(DB_INITIALIZE_QUERY)
        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
        // Serialize queued writes via a mutex and run them synchronously
        .with_write_queue_constructor(sqlez::thread_safe_connection::locking_queue())
        .build()
        .await;
    built.unwrap()
}
209
/// Implements a basic DB wrapper for a given domain.
///
/// `$t` must be a single-field tuple struct wrapping a `ThreadSafeConnection`
/// that also implements `sqlez::domain::Domain`. The macro generates:
/// - `Deref<Target = ThreadSafeConnection>` and `Clone` for `$t`;
/// - `$t::global(cx)`, backed by the per-App [`AppDatabase`];
/// - `$t::open_test_db(name)`, in the invoking crate's `test` /
///   `test-support` builds;
/// - an `inventory` registration of a [`DomainMigration`], so that
///   [`AppMigrator`] runs `$t`'s migrations.
///
/// Arguments:
/// - type of connection wrapper
/// - dependencies: domain types whose migrations [`AppMigrator`] runs before
///   this domain's migrations (only their `Domain::NAME` is recorded)
#[macro_export]
macro_rules! static_connection {
    ($t:ident, [ $($d:ty),* ]) => {
        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl ::std::clone::Clone for $t {
            fn clone(&self) -> Self {
                Self(self.0.clone())
            }
        }

        impl $t {
            /// Returns an instance backed by the per-App database.
            ///
            /// Delegates to `AppDatabase::global`; see its docs for what
            /// happens when no per-App database has been set (a test fallback
            /// in test builds of the db crate, a panic otherwise).
            pub fn global(cx: &$crate::gpui::App) -> Self {
                Self($crate::AppDatabase::global(cx).clone())
            }

            /// Opens a fresh test database named `name`, migrated with `$t`.
            #[cfg(any(test, feature = "test-support"))]
            pub async fn open_test_db(name: &'static str) -> Self {
                // NOTE(review): only `$t` is passed as the migrator, so the
                // migrations of the dependency domains listed in the macro
                // invocation are presumably not applied here — confirm intended.
                Self($crate::open_test_db::<$t>(name).await)
            }
        }

        $crate::inventory::submit! {
            $crate::DomainMigration {
                name: <$t as $crate::sqlez::domain::Domain>::NAME,
                migrations: <$t as $crate::sqlez::domain::Domain>::MIGRATIONS,
                dependencies: &[$(<$d as $crate::sqlez::domain::Domain>::NAME),*],
                should_allow_migration_change: <$t as $crate::sqlez::domain::Domain>::should_allow_migration_change,
            }
        }
    }
}
255
256pub fn write_and_log<F>(cx: &App, db_write: impl FnOnce() -> F + Send + 'static)
257where
258    F: Future<Output = anyhow::Result<()>> + Send,
259{
260    cx.background_spawn(async move { db_write().await.log_err() })
261        .detach()
262}
263
#[cfg(test)]
mod tests {
    use std::thread;

    use sqlez::domain::Domain;
    use sqlez_macros::sql;

    use crate::open_db;

    /// A domain whose migrations are themselves invalid (the second creates a
    /// table that already exists) must make `open_db` panic. Both the on-disk
    /// attempt and the in-memory fallback fail, and the fallback's `expect`
    /// panics.
    #[gpui::test]
    #[should_panic]
    async fn test_bad_migration_panics() {
        enum BadDB {}

        impl Domain for BadDB {
            const NAME: &str = "db_tests";
            const MIGRATIONS: &[&str] = &[
                sql!(CREATE TABLE test(value);),
                // failure because test already exists
                sql!(CREATE TABLE test(value);),
            ];
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        let _bad_db = open_db::<BadDB>(
            tempdir.path(),
            release_channel::ReleaseChannel::Dev.dev_name(),
        )
        .await;
    }

    /// An on-disk database already migrated by one domain is reopened by a
    /// different domain with the same `NAME` but different migrations.
    /// `open_db` must still return a usable connection on which the new
    /// domain's `test2` table exists.
    ///
    /// NOTE(review): `open_db` has no recreate step, so this presumably passes
    /// via the in-memory fallback — confirm.
    #[gpui::test]
    async fn test_db_corruption(cx: &mut gpui::TestAppContext) {
        cx.executor().allow_parking();

        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            const NAME: &str = "db_tests";
            const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test(value);)];
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            const NAME: &str = "db_tests"; //Notice same name
            const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test2(value);)];
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        {
            let corrupt_db = open_db::<CorruptedDB>(
                tempdir.path(),
                release_channel::ReleaseChannel::Dev.dev_name(),
            )
            .await;
            assert!(corrupt_db.persistent());
        }

        let good_db = open_db::<GoodDB>(
            tempdir.path(),
            release_channel::ReleaseChannel::Dev.dev_name(),
        )
        .await;
        assert!(
            good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                .unwrap()
                .is_none()
        );
    }

    /// Same scenario as `test_db_corruption`, but ten threads open the
    /// conflicting database concurrently. Every one of them must get a usable
    /// connection on which `test2` exists.
    #[gpui::test(iterations = 30)]
    async fn test_simultaneous_db_corruption(cx: &mut gpui::TestAppContext) {
        cx.executor().allow_parking();

        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            const NAME: &str = "db_tests";

            const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test(value);)];
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            const NAME: &str = "db_tests"; //Notice same name
            const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test2(value);)]; // But different migration
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        {
            // Setup the bad database
            let corrupt_db = open_db::<CorruptedDB>(
                tempdir.path(),
                release_channel::ReleaseChannel::Dev.dev_name(),
            )
            .await;
            assert!(corrupt_db.persistent());
        }

        // Try to connect to it a bunch of times at once
        let mut guards = vec![];
        for _ in 0..10 {
            let tmp_path = tempdir.path().to_path_buf();
            let guard = thread::spawn(move || {
                let good_db = smol::block_on(open_db::<GoodDB>(
                    tmp_path.as_path(),
                    release_channel::ReleaseChannel::Dev.dev_name(),
                ));
                assert!(
                    good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                        .unwrap()
                        .is_none()
                );
            });

            guards.push(guard);
        }

        for guard in guards {
            // Re-raise a worker's panic so the failure shows its original
            // message instead of a bare `is_ok()` assertion.
            if let Err(panic) = guard.join() {
                std::panic::resume_unwind(panic);
            }
        }
    }
}