1pub mod kvp;
2pub mod query;
3
4// Re-export
5pub use anyhow;
6use anyhow::Context as _;
7use gpui::{App, AppContext};
8pub use indoc::indoc;
9pub use paths::database_dir;
10pub use smol;
11pub use sqlez;
12pub use sqlez_macros;
13
14pub use release_channel::RELEASE_CHANNEL;
15use sqlez::domain::Migrator;
16use sqlez::thread_safe_connection::ThreadSafeConnection;
17use sqlez_macros::sql;
18use std::future::Future;
19use std::path::Path;
20use std::sync::{LazyLock, atomic::Ordering};
21use std::{env, sync::atomic::AtomicBool};
22use util::{ResultExt, maybe};
23
/// SQL handed to the connection builder via `with_connection_initialize_query`
/// (presumably executed for every new connection — confirm in `sqlez`).
///
/// `PRAGMA foreign_keys=TRUE` turns on enforcement of foreign-key constraints.
const CONNECTION_INITIALIZE_QUERY: &str = sql!(
    PRAGMA foreign_keys=TRUE;
);

/// SQL handed to the connection builder via `with_db_initialization_query`
/// (presumably run once per database rather than per connection — confirm in `sqlez`).
///
/// - `journal_mode=WAL`: write-ahead-log journaling.
/// - `busy_timeout=1`: wait at most 1 ms on a locked database before reporting busy.
/// - `case_sensitive_like=TRUE`: makes `LIKE` comparisons case-sensitive.
/// - `synchronous=NORMAL`: fewer fsyncs than `FULL`; with WAL this keeps the file
///   consistent but may lose the most recent commits on power loss.
const DB_INITIALIZE_QUERY: &str = sql!(
    PRAGMA journal_mode=WAL;
    PRAGMA busy_timeout=1;
    PRAGMA case_sensitive_like=TRUE;
    PRAGMA synchronous=NORMAL;
);

/// Name of the non-persistent database opened by `open_fallback_db`, used when
/// `ZED_STATELESS` is set or when the on-disk database cannot be opened.
const FALLBACK_DB_NAME: &str = "FALLBACK_MEMORY_DB";

/// File name of the SQLite database inside each scoped database directory.
const DB_FILE_NAME: &str = "db.sqlite";
38
/// `true` when the `ZED_STATELESS` environment variable is set to a non-empty value.
///
/// When set, [`open_db`] skips the on-disk database entirely and returns the in-memory
/// fallback database instead.
pub static ZED_STATELESS: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_STATELESS").is_ok_and(|value| !value.is_empty()));

/// Set by [`open_db`] when opening the on-disk database failed and the in-memory fallback
/// had to be used, so the failure can be escalated to the user.
pub static ALL_FILE_DB_FAILED: LazyLock<AtomicBool> = LazyLock::new(|| AtomicBool::new(false));
43
44/// Open or create a database at the given directory path.
45/// This will retry a couple times if there are failures. If opening fails once, the db directory
46/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
47/// In either case, static variables are set so that the user can be notified.
48pub async fn open_db<M: Migrator + 'static>(db_dir: &Path, scope: &str) -> ThreadSafeConnection<M> {
49 if *ZED_STATELESS {
50 return open_fallback_db().await;
51 }
52
53 let main_db_dir = db_dir.join(format!("0-{}", scope));
54
55 let connection = maybe!(async {
56 smol::fs::create_dir_all(&main_db_dir)
57 .await
58 .context("Could not create db directory")
59 .log_err()?;
60 let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
61 open_main_db(&db_path).await
62 })
63 .await;
64
65 if let Some(connection) = connection {
66 return connection;
67 }
68
69 // Set another static ref so that we can escalate the notification
70 ALL_FILE_DB_FAILED.store(true, Ordering::Release);
71
72 // If still failed, create an in memory db with a known name
73 open_fallback_db().await
74}
75
76async fn open_main_db<M: Migrator>(db_path: &Path) -> Option<ThreadSafeConnection<M>> {
77 log::info!("Opening main db");
78 ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
79 .with_db_initialization_query(DB_INITIALIZE_QUERY)
80 .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
81 .build()
82 .await
83 .log_err()
84}
85
86async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> {
87 log::info!("Opening fallback db");
88 ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false)
89 .with_db_initialization_query(DB_INITIALIZE_QUERY)
90 .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
91 .build()
92 .await
93 .expect(
94 "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
95 )
96}
97
/// Opens a non-persistent database named `db_name` for use in tests.
///
/// It uses the same initialization queries as production. The write queue is built with
/// `locking_queue`, which serializes queued writes via a mutex and runs them synchronously.
///
/// # Panics
///
/// Panics if the connection cannot be built.
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
    use sqlez::thread_safe_connection::locking_queue;

    let builder = ThreadSafeConnection::<M>::builder(db_name, false)
        .with_db_initialization_query(DB_INITIALIZE_QUERY)
        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
        .with_write_queue_constructor(locking_queue());
    builder.build().await.unwrap()
}
111
/// Implements a basic DB wrapper for a given domain.
///
/// `define_connection! { pub static ref DB: MyDb<Deps...> = &[migrations...]; [global] }`
/// generates:
/// - `pub struct MyDb(ThreadSafeConnection<Migrator>)`, which derefs to the connection. The
///   migrator is `MyDb` itself when written as `MyDb<()>`, or the tuple `(Deps..., MyDb)` when
///   dependency domains are listed.
/// - A `Domain` impl for `MyDb`, whose name is the stringified type name and whose migrations
///   are the given expression.
/// - A lazily initialized `pub static DB: LazyLock<MyDb>`:
///   - Under the invoking crate's `test` cfg or `test-support` feature, it opens a test
///     database named after the static.
///   - Otherwise it opens the on-disk database in `database_dir()`.
///   - The on-disk database is scoped to `"global"` when the optional trailing identifier is
///     `global`, and to the release channel's dev name in every other case.
///
/// The `@impl` arm is an implementation detail shared by both public forms.
#[macro_export]
macro_rules! define_connection {
    (pub static ref $id:ident: $t:ident<()> = $migrations:expr; $($global:ident)?) => {
        $crate::define_connection!(@impl $id, $t, $t, $migrations; $($global)?);
    };
    (pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr; $($global:ident)?) => {
        $crate::define_connection!(@impl $id, $t, ($($d),+, $t), $migrations; $($global)?);
    };
    (@impl $id:ident, $t:ident, $migrator:ty, $migrations:expr; $($global:ident)?) => {
        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$migrator>);

        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$migrator>;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl $crate::sqlez::domain::Domain for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn migrations() -> &'static [&'static str] {
                $migrations
            }
        }

        #[cfg(any(test, feature = "test-support"))]
        pub static $id: ::std::sync::LazyLock<$t> = ::std::sync::LazyLock::new(|| {
            $t($crate::smol::block_on($crate::open_test_db(stringify!($id))))
        });

        #[cfg(not(any(test, feature = "test-support")))]
        pub static $id: ::std::sync::LazyLock<$t> = ::std::sync::LazyLock::new(|| {
            let db_dir = $crate::database_dir();
            // Only the identifier `global` selects the shared scope; anything else (or
            // nothing) scopes the database to the current release channel.
            let scope = if false $(|| stringify!($global) == "global")? {
                "global"
            } else {
                $crate::RELEASE_CHANNEL.dev_name()
            };
            $t($crate::smol::block_on($crate::open_db(db_dir, scope)))
        });
    };
}
190
191pub fn write_and_log<F>(cx: &App, db_write: impl FnOnce() -> F + Send + 'static)
192where
193 F: Future<Output = anyhow::Result<()>> + Send,
194{
195 cx.background_spawn(async move { db_write().await.log_err() })
196 .detach()
197}
198
#[cfg(test)]
mod tests {
    use std::thread;

    use sqlez::domain::Domain;
    use sqlez_macros::sql;

    use crate::open_db;

    /// A domain whose migrations fail even on a fresh database must panic.
    ///
    /// The on-disk open fails, so `open_db` falls back to the in-memory database. That
    /// database runs the same broken migrations and hits the `expect` in `open_fallback_db`.
    #[gpui::test]
    #[should_panic]
    async fn test_bad_migration_panics() {
        enum BadDB {}

        impl Domain for BadDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[
                    sql!(CREATE TABLE test(value);),
                    // failure because test already exists
                    sql!(CREATE TABLE test(value);),
                ]
            }
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        let _bad_db = open_db::<BadDB>(
            tempdir.path(),
            release_channel::ReleaseChannel::Dev.dev_name(),
        )
        .await;
    }

    /// Opens a database with one migration set, then reopens the same directory with a domain
    /// that has the same name but a different migration.
    ///
    /// The second open must still yield a usable connection containing the new domain's
    /// table, whether it is served by the on-disk file or by the in-memory fallback.
    #[gpui::test]
    async fn test_db_corruption(cx: &mut gpui::TestAppContext) {
        cx.executor().allow_parking();

        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test(value);)]
            }
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            fn name() -> &'static str {
                "db_tests" //Notice same name
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test2(value);)] //But different migration
            }
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        {
            let corrupt_db = open_db::<CorruptedDB>(
                tempdir.path(),
                release_channel::ReleaseChannel::Dev.dev_name(),
            )
            .await;
            assert!(corrupt_db.persistent());
        }

        let good_db = open_db::<GoodDB>(
            tempdir.path(),
            release_channel::ReleaseChannel::Dev.dev_name(),
        )
        .await;
        assert!(
            good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                .unwrap()
                .is_none()
        );
    }

    /// Same scenario as `test_db_corruption`, but ten threads open the mismatched database
    /// concurrently. Every thread must end up with a usable connection.
    #[gpui::test(iterations = 30)]
    async fn test_simultaneous_db_corruption(cx: &mut gpui::TestAppContext) {
        cx.executor().allow_parking();

        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test(value);)]
            }
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            fn name() -> &'static str {
                "db_tests" //Notice same name
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test2(value);)] //But different migration
            }
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        {
            // Setup the bad database
            let corrupt_db = open_db::<CorruptedDB>(
                tempdir.path(),
                release_channel::ReleaseChannel::Dev.dev_name(),
            )
            .await;
            assert!(corrupt_db.persistent());
        }

        // Try to connect to it a bunch of times at once
        let mut guards = vec![];
        for _ in 0..10 {
            let tmp_path = tempdir.path().to_path_buf();
            let guard = thread::spawn(move || {
                let good_db = smol::block_on(open_db::<GoodDB>(
                    tmp_path.as_path(),
                    release_channel::ReleaseChannel::Dev.dev_name(),
                ));
                assert!(
                    good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                        .unwrap()
                        .is_none()
                );
            });

            guards.push(guard);
        }

        for guard in guards {
            assert!(guard.join().is_ok());
        }
    }
}