1pub mod kvp;
2pub mod query;
3
4// Re-export
5pub use anyhow;
6use anyhow::Context;
7use gpui::AppContext;
8pub use indoc::indoc;
9pub use lazy_static;
10pub use paths::database_dir;
11pub use smol;
12pub use sqlez;
13pub use sqlez_macros;
14
15use release_channel::ReleaseChannel;
16pub use release_channel::RELEASE_CHANNEL;
17use sqlez::domain::Migrator;
18use sqlez::thread_safe_connection::ThreadSafeConnection;
19use sqlez_macros::sql;
20use std::future::Future;
21use std::path::{Path, PathBuf};
22use std::sync::atomic::{AtomicBool, Ordering};
23use util::{maybe, ResultExt};
24
/// Per-connection setup, passed to the builder's `with_connection_initialize_query`
/// by every opener in this file (main, fallback and test databases).
/// `foreign_keys=TRUE` turns on SQLite foreign-key enforcement.
const CONNECTION_INITIALIZE_QUERY: &str = sql!(
    PRAGMA foreign_keys=TRUE;
);

/// Database-level setup, passed to the builder's `with_db_initialization_query`
/// by every opener in this file.
// NOTE(review): SQLite documents `busy_timeout` in milliseconds, so this waits only
// 1 ms on a locked database before returning SQLITE_BUSY — confirm this is intentional.
const DB_INITIALIZE_QUERY: &str = sql!(
    PRAGMA journal_mode=WAL;
    PRAGMA busy_timeout=1;
    PRAGMA case_sensitive_like=TRUE;
    PRAGMA synchronous=NORMAL;
);

/// Name of the shared in-memory database that `open_fallback_db` opens (with the
/// builder's second argument set to `false`). All domains that fall back share it.
const FALLBACK_DB_NAME: &str = "FALLBACK_MEMORY_DB";

/// File name of the on-disk database inside the per-release-channel directory
/// that `open_db` creates.
const DB_FILE_NAME: &str = "db.sqlite";
39
40lazy_static::lazy_static! {
41 pub static ref ZED_STATELESS: bool = std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty());
42 pub static ref ALL_FILE_DB_FAILED: AtomicBool = AtomicBool::new(false);
43}
44
45/// Open or create a database at the given directory path.
46/// This will retry a couple times if there are failures. If opening fails once, the db directory
47/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
48/// In either case, static variables are set so that the user can be notified.
49pub async fn open_db<M: Migrator + 'static>(
50 db_dir: &Path,
51 release_channel: &ReleaseChannel,
52) -> ThreadSafeConnection<M> {
53 if *ZED_STATELESS {
54 return open_fallback_db().await;
55 }
56
57 let release_channel_name = release_channel.dev_name();
58 let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel_name)));
59
60 let connection = maybe!(async {
61 smol::fs::create_dir_all(&main_db_dir)
62 .await
63 .context("Could not create db directory")
64 .log_err()?;
65 let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
66 open_main_db(&db_path).await
67 })
68 .await;
69
70 if let Some(connection) = connection {
71 return connection;
72 }
73
74 // Set another static ref so that we can escalate the notification
75 ALL_FILE_DB_FAILED.store(true, Ordering::Release);
76
77 // If still failed, create an in memory db with a known name
78 open_fallback_db().await
79}
80
81async fn open_main_db<M: Migrator>(db_path: &PathBuf) -> Option<ThreadSafeConnection<M>> {
82 log::info!("Opening main db");
83 ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
84 .with_db_initialization_query(DB_INITIALIZE_QUERY)
85 .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
86 .build()
87 .await
88 .log_err()
89}
90
91async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> {
92 log::info!("Opening fallback db");
93 ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false)
94 .with_db_initialization_query(DB_INITIALIZE_QUERY)
95 .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
96 .build()
97 .await
98 .expect(
99 "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
100 )
101}
102
/// Opens a non-persistent database named `db_name` for tests.
///
/// It uses the same initialization queries as the real databases. Queued writes go
/// through `locking_queue`, which serializes them behind a mutex and runs them
/// synchronously, so tests observe writes deterministically.
///
/// # Panics
///
/// Panics if the connection cannot be built.
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
    use sqlez::thread_safe_connection::locking_queue;

    let built = ThreadSafeConnection::<M>::builder(db_name, false)
        .with_db_initialization_query(DB_INITIALIZE_QUERY)
        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
        .with_write_queue_constructor(locking_queue())
        .build()
        .await;
    built.unwrap()
}
116
/// Defines a database connection type for a single sqlez `Domain`, plus a global
/// instance of it.
///
/// Two forms are accepted:
///
/// ```ignore
/// // Connection typed over `MyDb` alone.
/// define_connection!(pub static ref DB: MyDb<()> = &[sql!(CREATE TABLE ...)];);
/// // Connection typed over the tuple `(OtherDomain, MyDb)`.
/// define_connection!(pub static ref DB: MyDb<OtherDomain> = &[sql!(...)];);
/// ```
///
/// Both forms expand to:
/// - `pub struct $t`, a newtype around `ThreadSafeConnection` that implements `Deref` to
///   the connection. The connection is typed over `$t` alone for `<()>`, or over
///   `(D1, ..., Dn, $t)` for `<D1, ..., Dn>`.
/// - `impl Domain for $t`, where `name()` is the type's name and `migrations()` returns
///   the given expression.
/// - `pub static ref $id: $t`, lazily initialized by blocking on `open_test_db` (named
///   after `$id`) in `test` / `test-support` builds, and on
///   `open_db(database_dir(), &RELEASE_CHANNEL)` otherwise.
#[macro_export]
macro_rules! define_connection {
    // Must stay before the generic arm: `()` is itself a valid `ty`.
    (pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => {
        $crate::define_connection!(
            @define $id, $t,
            $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>,
            $migrations
        );
    };
    (pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => {
        $crate::define_connection!(
            @define $id, $t,
            $crate::sqlez::thread_safe_connection::ThreadSafeConnection<($($d),+, $t)>,
            $migrations
        );
    };
    // Internal rule holding the expansion shared by both public forms above.
    (@define $id:ident, $t:ident, $connection:ty, $migrations:expr) => {
        pub struct $t($connection);

        impl ::std::ops::Deref for $t {
            type Target = $connection;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl $crate::sqlez::domain::Domain for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn migrations() -> &'static [&'static str] {
                $migrations
            }
        }

        #[cfg(any(test, feature = "test-support"))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
        }

        #[cfg(not(any(test, feature = "test-support")))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db($crate::database_dir(), &$crate::RELEASE_CHANNEL)));
        }
    };
}
183
184pub fn write_and_log<F>(cx: &mut AppContext, db_write: impl FnOnce() -> F + Send + 'static)
185where
186 F: Future<Output = anyhow::Result<()>> + Send,
187{
188 cx.background_executor()
189 .spawn(async move { db_write().await.log_err() })
190 .detach()
191}
192
#[cfg(test)]
mod tests {
    use std::thread;

    use sqlez::domain::Domain;
    use sqlez_macros::sql;

    use crate::open_db;

    /// Opening a domain whose migrations fail must panic.
    ///
    /// `open_main_db` only logs the failure (via `log_err`), so `open_db` falls back to
    /// the in-memory database. The same migrations presumably fail there too, and
    /// `open_fallback_db`'s `expect` then panics. As a side effect this sets the global
    /// `ALL_FILE_DB_FAILED` flag.
    #[gpui::test]
    #[should_panic]
    async fn test_bad_migration_panics() {
        enum BadDB {}

        impl Domain for BadDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[
                    sql!(CREATE TABLE test(value);),
                    // failure because test already exists
                    sql!(CREATE TABLE test(value);),
                ]
            }
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        let _bad_db = open_db::<BadDB>(tempdir.path(), &release_channel::ReleaseChannel::Dev).await;
    }

    /// Opens a domain on disk, then reopens the same domain name with different migrations.
    ///
    /// NOTE(review): the original description said this causes a "recreate", but `open_db`
    /// has no recreate or backup path. Presumably sqlez rejects the changed migration,
    /// `open_main_db` returns `None`, and `GoodDB` lands on the in-memory fallback. The
    /// test only checks that `test2` exists and has no rows.
    #[gpui::test]
    async fn test_db_corruption(cx: &mut gpui::TestAppContext) {
        cx.executor().allow_parking();

        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test(value);)]
            }
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            fn name() -> &'static str {
                "db_tests" //Notice same name
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test2(value);)] //But different migration
            }
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        // Scoped so the first connection is dropped before the domain is reopened.
        {
            let corrupt_db =
                open_db::<CorruptedDB>(tempdir.path(), &release_channel::ReleaseChannel::Dev).await;
            assert!(corrupt_db.persistent());
        }

        let good_db =
            open_db::<GoodDB>(tempdir.path(), &release_channel::ReleaseChannel::Dev).await;
        assert!(
            good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                .unwrap()
                .is_none()
        );
    }

    /// Same setup as `test_db_corruption`, but 10 threads reopen `GoodDB` at the same time.
    /// Each thread must see an empty `test2` table. Runs for 30 iterations.
    ///
    /// NOTE(review): as above, there is no "recreate" path in `open_db`; the threads
    /// presumably all end up on the shared in-memory fallback database.
    #[gpui::test(iterations = 30)]
    async fn test_simultaneous_db_corruption(cx: &mut gpui::TestAppContext) {
        cx.executor().allow_parking();

        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test(value);)]
            }
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            fn name() -> &'static str {
                "db_tests" //Notice same name
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test2(value);)] //But different migration
            }
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        {
            // Setup the bad database (dropped at the end of this scope)
            let corrupt_db =
                open_db::<CorruptedDB>(tempdir.path(), &release_channel::ReleaseChannel::Dev).await;
            assert!(corrupt_db.persistent());
        }

        // Try to connect to it a bunch of times at once
        let mut guards = vec![];
        for _ in 0..10 {
            let tmp_path = tempdir.path().to_path_buf();
            let guard = thread::spawn(move || {
                let good_db = smol::block_on(open_db::<GoodDB>(
                    tmp_path.as_path(),
                    &release_channel::ReleaseChannel::Dev,
                ));
                assert!(
                    good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                        .unwrap()
                        .is_none()
                );
            });

            guards.push(guard);
        }

        // A panicking thread (failed assertion) surfaces here as a join error.
        for guard in guards.into_iter() {
            assert!(guard.join().is_ok());
        }
    }
}