1pub mod kvp;
2pub mod query;
3
4// Re-export
5pub use anyhow;
6use anyhow::Context as _;
7pub use gpui;
8use gpui::{App, AppContext, Global};
9pub use indoc::indoc;
10pub use inventory;
11pub use paths::database_dir;
12pub use smol;
13pub use sqlez;
14pub use sqlez_macros;
15pub use uuid;
16
17pub use release_channel::RELEASE_CHANNEL;
18use sqlez::domain::Migrator;
19use sqlez::thread_safe_connection::ThreadSafeConnection;
20use sqlez_macros::sql;
21use std::future::Future;
22use std::path::Path;
23use std::sync::atomic::AtomicBool;
24use std::sync::{LazyLock, atomic::Ordering};
25use util::{ResultExt, maybe};
26use zed_env_vars::ZED_STATELESS;
27
/// One domain's migrations, submitted to `inventory` by `static_connection!`
/// and gathered at link time so they can all be applied to the shared database.
pub struct DomainMigration {
    /// Name of the domain (filled from `Domain::NAME` by `static_connection!`).
    pub name: &'static str,
    /// The domain's migration statements, in the order they are applied.
    pub migrations: &'static [&'static str],
    /// Names of the domains whose migrations must run before this one's.
    pub dependencies: &'static [&'static str],
    /// Callback forwarded to the connection's `migrate` call; decides whether an
    /// already-applied migration step may differ from the current text.
    /// NOTE(review): argument meaning (step index, old, new?) lives in sqlez — confirm.
    pub should_allow_migration_change: fn(usize, &str, &str) -> bool,
}
35
// Registers `DomainMigration` as an inventory-collected type. This is what lets
// `static_connection!` call `inventory::submit!` for each domain, and lets
// `AppMigrator::migrate` enumerate them with `inventory::iter::<DomainMigration>()`.
inventory::collect!(DomainMigration);
37
38/// The shared database connection backing all domain-specific DB wrappers.
39/// Set as a GPUI global per-App. Falls back to a shared LazyLock if not set.
40pub struct AppDatabase(pub ThreadSafeConnection);
41
42impl Global for AppDatabase {}
43
44/// Migrator that runs all inventory-registered domain migrations.
45pub struct AppMigrator;
46
47impl Migrator for AppMigrator {
48 fn migrate(connection: &sqlez::connection::Connection) -> anyhow::Result<()> {
49 let registrations: Vec<&DomainMigration> = inventory::iter::<DomainMigration>().collect();
50 let sorted = topological_sort(®istrations);
51 for reg in &sorted {
52 let mut should_allow = reg.should_allow_migration_change;
53 connection.migrate(reg.name, reg.migrations, &mut should_allow)?;
54 }
55 Ok(())
56 }
57}
58
59impl AppDatabase {
60 /// Opens the production database and runs all inventory-registered
61 /// migrations in dependency order.
62 pub fn new() -> Self {
63 let db_dir = database_dir();
64 let scope = RELEASE_CHANNEL.dev_name();
65 let connection = smol::block_on(open_db::<AppMigrator>(db_dir, scope));
66 Self(connection)
67 }
68
69 /// Creates a new in-memory database with a unique name and runs all
70 /// inventory-registered migrations in dependency order.
71 #[cfg(any(test, feature = "test-support"))]
72 pub fn test_new() -> Self {
73 let name = format!("test-db-{}", uuid::Uuid::new_v4());
74 let connection = smol::block_on(open_test_db::<AppMigrator>(&name));
75 Self(connection)
76 }
77
78 /// Returns the per-App connection if set, otherwise falls back to
79 /// the shared LazyLock.
80 pub fn global(cx: &App) -> &ThreadSafeConnection {
81 #[allow(unreachable_code)]
82 if let Some(db) = cx.try_global::<Self>() {
83 return &db.0;
84 } else {
85 #[cfg(any(feature = "test-support", test))]
86 return &TEST_APP_DATABASE.0;
87
88 panic!("database not initialized")
89 }
90 }
91}
92
93fn topological_sort<'a>(registrations: &[&'a DomainMigration]) -> Vec<&'a DomainMigration> {
94 let mut sorted: Vec<&DomainMigration> = Vec::new();
95 let mut visited: std::collections::HashSet<&str> = std::collections::HashSet::new();
96
97 fn visit<'a>(
98 name: &str,
99 registrations: &[&'a DomainMigration],
100 sorted: &mut Vec<&'a DomainMigration>,
101 visited: &mut std::collections::HashSet<&'a str>,
102 ) {
103 if visited.contains(name) {
104 return;
105 }
106 if let Some(reg) = registrations.iter().find(|r| r.name == name) {
107 for dep in reg.dependencies {
108 visit(dep, registrations, sorted, visited);
109 }
110 visited.insert(reg.name);
111 sorted.push(reg);
112 }
113 }
114
115 for reg in registrations {
116 visit(reg.name, registrations, &mut sorted, &mut visited);
117 }
118 sorted
119}
120
/// Shared fallback `AppDatabase` used when no per-App global is set.
///
/// Only compiled for `test` / `test-support` builds. It is created lazily, on
/// first use, by `AppDatabase::test_new`: a uniquely named in-memory database
/// shared by every `App` in the process that has no `AppDatabase` global.
#[cfg(any(test, feature = "test-support"))]
static TEST_APP_DATABASE: LazyLock<AppDatabase> = LazyLock::new(AppDatabase::test_new);

/// Passed to `with_connection_initialize_query` by every builder in this file
/// (presumably executed on each new connection — confirm in sqlez). Turns on
/// SQLite foreign-key enforcement.
const CONNECTION_INITIALIZE_QUERY: &str = sql!(
    PRAGMA foreign_keys=TRUE;
);

/// Passed to `with_db_initialization_query` by every builder in this file.
/// Selects write-ahead logging, waits up to 500 ms on a locked database,
/// makes `LIKE` case-sensitive, and sets `synchronous=NORMAL`.
const DB_INITIALIZE_QUERY: &str = sql!(
    PRAGMA journal_mode=WAL;
    PRAGMA busy_timeout=500;
    PRAGMA case_sensitive_like=TRUE;
    PRAGMA synchronous=NORMAL;
);

/// Name given to the non-persistent database opened by `open_fallback_db`.
const FALLBACK_DB_NAME: &str = "FALLBACK_MEMORY_DB";

/// File name of the on-disk database inside its per-scope directory.
const DB_FILE_NAME: &str = "db.sqlite";

/// Set to `true` (with `Ordering::Release`) by `open_db` when the on-disk
/// database could not be opened and the in-memory fallback was used instead,
/// so the user can be notified. Presumably read with `Ordering::Acquire` by
/// callers elsewhere — confirm.
// NOTE(review): `AtomicBool::new` is a `const fn`, so a plain `static AtomicBool`
// would not need `LazyLock`; kept as-is to avoid changing this public item's type.
pub static ALL_FILE_DB_FAILED: LazyLock<AtomicBool> = LazyLock::new(|| AtomicBool::new(false));
141
142/// Open or create a database at the given directory path.
143/// This will retry a couple times if there are failures. If opening fails once, the db directory
144/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
145/// In either case, static variables are set so that the user can be notified.
146pub async fn open_db<M: Migrator + 'static>(db_dir: &Path, scope: &str) -> ThreadSafeConnection {
147 if *ZED_STATELESS {
148 return open_fallback_db::<M>().await;
149 }
150
151 let main_db_dir = db_dir.join(format!("0-{}", scope));
152
153 let connection = maybe!(async {
154 smol::fs::create_dir_all(&main_db_dir)
155 .await
156 .context("Could not create db directory")
157 .log_err()?;
158 let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
159 open_main_db::<M>(&db_path).await
160 })
161 .await;
162
163 if let Some(connection) = connection {
164 return connection;
165 }
166
167 // Set another static ref so that we can escalate the notification
168 ALL_FILE_DB_FAILED.store(true, Ordering::Release);
169
170 // If still failed, create an in memory db with a known name
171 open_fallback_db::<M>().await
172}
173
174async fn open_main_db<M: Migrator>(db_path: &Path) -> Option<ThreadSafeConnection> {
175 log::trace!("Opening database {}", db_path.display());
176 ThreadSafeConnection::builder::<M>(db_path.to_string_lossy().as_ref(), true)
177 .with_db_initialization_query(DB_INITIALIZE_QUERY)
178 .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
179 .build()
180 .await
181 .log_err()
182}
183
184async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection {
185 log::warn!("Opening fallback in-memory database");
186 ThreadSafeConnection::builder::<M>(FALLBACK_DB_NAME, false)
187 .with_db_initialization_query(DB_INITIALIZE_QUERY)
188 .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
189 .build()
190 .await
191 .expect(
192 "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
193 )
194}
195
/// Opens a non-persistent database named `db_name` for tests, with `M` as its
/// migrator.
///
/// # Panics
///
/// Panics if the database cannot be built.
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection {
    use sqlez::thread_safe_connection::locking_queue;

    // Queued writes are serialized through a mutex and executed synchronously.
    let builder = ThreadSafeConnection::builder::<M>(db_name, false)
        .with_db_initialization_query(DB_INITIALIZE_QUERY)
        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
        .with_write_queue_constructor(locking_queue());
    let built = builder.build().await;
    built.unwrap()
}
209
/// Implements a basic DB wrapper for a given domain.
///
/// Arguments:
/// - type of connection wrapper: a tuple struct whose single field is a
///   `ThreadSafeConnection` (the expansion reads `.0` and constructs it with
///   `$t(..)`) and which implements `sqlez::domain::Domain`
/// - dependencies, whose migrations should be run prior to this domain's
///   migrations; a trailing comma is accepted
///
/// The expansion implements `Deref<Target = ThreadSafeConnection>` and `Clone`
/// for the wrapper and adds a `global(cx)` constructor backed by
/// `AppDatabase::global`. When the invoking crate is built with `test` or its
/// `test-support` feature, it also adds an `open_test_db(name)` constructor.
/// Finally, it submits a `DomainMigration` to `inventory`, so that
/// `AppMigrator` runs this domain's migrations after those of its dependencies.
#[macro_export]
macro_rules! static_connection {
    ($t:ident, [ $($d:ty),* $(,)? ]) => {
        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl ::std::clone::Clone for $t {
            fn clone(&self) -> Self {
                $t(self.0.clone())
            }
        }

        impl $t {
            /// Returns an instance backed by the per-App `AppDatabase` global.
            ///
            /// See `AppDatabase::global` for what happens when no global is
            /// set (test fallback or panic, depending on the build).
            pub fn global(cx: &$crate::gpui::App) -> Self {
                $t($crate::AppDatabase::global(cx).clone())
            }

            /// Opens a non-persistent test database named `name`, using this
            /// domain as the migrator.
            #[cfg(any(test, feature = "test-support"))]
            pub async fn open_test_db(name: &str) -> Self {
                $t($crate::open_test_db::<$t>(name).await)
            }
        }

        $crate::inventory::submit! {
            $crate::DomainMigration {
                name: <$t as $crate::sqlez::domain::Domain>::NAME,
                migrations: <$t as $crate::sqlez::domain::Domain>::MIGRATIONS,
                dependencies: &[$(<$d as $crate::sqlez::domain::Domain>::NAME),*],
                should_allow_migration_change: <$t as $crate::sqlez::domain::Domain>::should_allow_migration_change,
            }
        }
    };
}
255
256pub fn write_and_log<F>(cx: &App, db_write: impl FnOnce() -> F + Send + 'static)
257where
258 F: Future<Output = anyhow::Result<()>> + Send,
259{
260 cx.background_spawn(async move { db_write().await.log_err() })
261 .detach()
262}
263
#[cfg(test)]
mod tests {
    use std::thread;

    use sqlez::domain::Domain;
    use sqlez_macros::sql;

    use crate::open_db;

    /// A domain whose migrations are broken must make `open_db` panic: the
    /// on-disk open presumably fails on the second migration, and the
    /// in-memory fallback then runs the same broken migrations and hits the
    /// `expect` in `open_fallback_db`.
    #[gpui::test]
    #[should_panic]
    async fn test_bad_migration_panics() {
        enum BadDB {}

        impl Domain for BadDB {
            const NAME: &str = "db_tests";
            const MIGRATIONS: &[&str] = &[
                sql!(CREATE TABLE test(value);),
                // failure because test already exists
                sql!(CREATE TABLE test(value);),
            ];
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        let _bad_db = open_db::<BadDB>(
            tempdir.path(),
            release_channel::ReleaseChannel::Dev.dev_name(),
        )
        .await;
    }

    /// Creates an on-disk database with one set of migrations, then reopens the
    /// same directory with a domain of the same name but different migrations,
    /// and checks that the resulting connection has the new `test2` table.
    ///
    /// NOTE(review): this was described as "causing recreate"; from this file
    /// alone the second open may instead land on the in-memory fallback (only
    /// `test2`'s existence is asserted, not persistence) — confirm which path runs.
    #[gpui::test]
    async fn test_db_corruption(cx: &mut gpui::TestAppContext) {
        cx.executor().allow_parking();

        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            const NAME: &str = "db_tests";
            const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test(value);)];
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            const NAME: &str = "db_tests"; // Notice: same name
            const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test2(value);)];
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        {
            let corrupt_db = open_db::<CorruptedDB>(
                tempdir.path(),
                release_channel::ReleaseChannel::Dev.dev_name(),
            )
            .await;
            assert!(corrupt_db.persistent());
        }

        let good_db = open_db::<GoodDB>(
            tempdir.path(),
            release_channel::ReleaseChannel::Dev.dev_name(),
        )
        .await;
        assert!(
            good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                .unwrap()
                .is_none()
        );
    }

    /// Same scenario as `test_db_corruption`, but after the first database is
    /// created, 10 threads each reopen it concurrently with the conflicting
    /// domain. Every thread must end up with a connection that has `test2`.
    #[gpui::test(iterations = 30)]
    async fn test_simultaneous_db_corruption(cx: &mut gpui::TestAppContext) {
        cx.executor().allow_parking();

        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            const NAME: &str = "db_tests";

            const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test(value);)];
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            const NAME: &str = "db_tests"; // Notice: same name
            const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test2(value);)]; // But different migration
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        {
            // Set up the database with the conflicting migrations
            let corrupt_db = open_db::<CorruptedDB>(
                tempdir.path(),
                release_channel::ReleaseChannel::Dev.dev_name(),
            )
            .await;
            assert!(corrupt_db.persistent());
        }

        // Try to connect to it from several threads at once
        let mut guards = vec![];
        for _ in 0..10 {
            let tmp_path = tempdir.path().to_path_buf();
            let guard = thread::spawn(move || {
                let good_db = smol::block_on(open_db::<GoodDB>(
                    tmp_path.as_path(),
                    release_channel::ReleaseChannel::Dev.dev_name(),
                ));
                assert!(
                    good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                        .unwrap()
                        .is_none()
                );
            });

            guards.push(guard);
        }

        // A panicking assertion in any thread surfaces here as a join error
        for guard in guards.into_iter() {
            assert!(guard.join().is_ok());
        }
    }
}