1pub mod kvp;
2pub mod query;
3
4// Re-export
5pub use anyhow;
6use anyhow::Context;
7use gpui::AppContext;
8pub use indoc::indoc;
9pub use lazy_static;
10use parking_lot::{Mutex, RwLock};
11pub use smol;
12pub use sqlez;
13pub use sqlez_macros;
14pub use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
15pub use util::paths::DB_DIR;
16
17use sqlez::domain::Migrator;
18use sqlez::thread_safe_connection::ThreadSafeConnection;
19use sqlez_macros::sql;
20use std::fs::create_dir_all;
21use std::future::Future;
22use std::path::{Path, PathBuf};
23use std::sync::atomic::{AtomicBool, Ordering};
24use std::time::{SystemTime, UNIX_EPOCH};
25use util::channel::ReleaseChannel;
26use util::{async_iife, ResultExt};
27
28const CONNECTION_INITIALIZE_QUERY: &'static str = sql!(
29 PRAGMA foreign_keys=TRUE;
30);
31
32const DB_INITIALIZE_QUERY: &'static str = sql!(
33 PRAGMA journal_mode=WAL;
34 PRAGMA busy_timeout=1;
35 PRAGMA case_sensitive_like=TRUE;
36 PRAGMA synchronous=NORMAL;
37);
38
39const FALLBACK_DB_NAME: &'static str = "FALLBACK_MEMORY_DB";
40
41const DB_FILE_NAME: &'static str = "db.sqlite";
42
43lazy_static::lazy_static! {
44 // !!!!!!! CHANGE BACK TO DEFAULT FALSE BEFORE SHIPPING
45 static ref ZED_STATELESS: bool = std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty());
46 pub static ref BACKUP_DB_PATH: RwLock<Option<PathBuf>> = RwLock::new(None);
47 pub static ref ALL_FILE_DB_FAILED: AtomicBool = AtomicBool::new(false);
48}
49static DB_FILE_OPERATIONS: Mutex<()> = Mutex::new(());
50
51/// Open or create a database at the given directory path.
52/// This will retry a couple times if there are failures. If opening fails once, the db directory
53/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
54/// In either case, static variables are set so that the user can be notified.
55pub async fn open_db<M: Migrator + 'static>(
56 db_dir: &Path,
57 release_channel: &ReleaseChannel,
58) -> ThreadSafeConnection<M> {
59 if *ZED_STATELESS {
60 return open_fallback_db().await;
61 }
62
63 let release_channel_name = release_channel.dev_name();
64 let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel_name)));
65
66 let connection = async_iife!({
67 // Note: This still has a race condition where 1 set of migrations succeeds
68 // (e.g. (Workspace, Editor)) and another fails (e.g. (Workspace, Terminal))
69 // This will cause the first connection to have the database taken out
70 // from under it. This *should* be fine though. The second dabatase failure will
71 // cause errors in the log and so should be observed by developers while writing
72 // soon-to-be good migrations. If user databases are corrupted, we toss them out
73 // and try again from a blank. As long as running all migrations from start to end
74 // on a blank database is ok, this race condition will never be triggered.
75 //
76 // Basically: Don't ever push invalid migrations to stable or everyone will have
77 // a bad time.
78
79 // If no db folder, create one at 0-{channel}
80 create_dir_all(&main_db_dir).context("Could not create db directory")?;
81 let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
82
83 // Optimistically open databases in parallel
84 if !DB_FILE_OPERATIONS.is_locked() {
85 // Try building a connection
86 if let Some(connection) = open_main_db(&db_path).await {
87 return Ok(connection)
88 };
89 }
90
91 // Take a lock in the failure case so that we move the db once per process instead
92 // of potentially multiple times from different threads. This shouldn't happen in the
93 // normal path
94 let _lock = DB_FILE_OPERATIONS.lock();
95 if let Some(connection) = open_main_db(&db_path).await {
96 return Ok(connection)
97 };
98
99 let backup_timestamp = SystemTime::now()
100 .duration_since(UNIX_EPOCH)
101 .expect("System clock is set before the unix timestamp, Zed does not support this region of spacetime")
102 .as_millis();
103
104 // If failed, move 0-{channel} to {current unix timestamp}-{channel}
105 let backup_db_dir = db_dir.join(Path::new(&format!(
106 "{}-{}",
107 backup_timestamp,
108 release_channel_name,
109 )));
110
111 std::fs::rename(&main_db_dir, &backup_db_dir)
112 .context("Failed clean up corrupted database, panicking.")?;
113
114 // Set a static ref with the failed timestamp and error so we can notify the user
115 {
116 let mut guard = BACKUP_DB_PATH.write();
117 *guard = Some(backup_db_dir);
118 }
119
120 // Create a new 0-{channel}
121 create_dir_all(&main_db_dir).context("Should be able to create the database directory")?;
122 let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
123
124 // Try again
125 open_main_db(&db_path).await.context("Could not newly created db")
126 }).await.log_err();
127
128 if let Some(connection) = connection {
129 return connection;
130 }
131
132 // Set another static ref so that we can escalate the notification
133 ALL_FILE_DB_FAILED.store(true, Ordering::Release);
134
135 // If still failed, create an in memory db with a known name
136 open_fallback_db().await
137}
138
139async fn open_main_db<M: Migrator>(db_path: &PathBuf) -> Option<ThreadSafeConnection<M>> {
140 log::info!("Opening main db");
141 ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
142 .with_db_initialization_query(DB_INITIALIZE_QUERY)
143 .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
144 .build()
145 .await
146 .log_err()
147}
148
149async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> {
150 log::info!("Opening fallback db");
151 ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false)
152 .with_db_initialization_query(DB_INITIALIZE_QUERY)
153 .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
154 .build()
155 .await
156 .expect(
157 "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
158 )
159}
160
/// Open a test database named `db_name` with the persistence flag off.
///
/// Queued writes are serialized through a mutex and run synchronously (see
/// `locking_queue`).
///
/// # Panics
///
/// Panics if the database cannot be built.
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
    let builder = ThreadSafeConnection::<M>::builder(db_name, false)
        .with_db_initialization_query(DB_INITIALIZE_QUERY)
        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
        .with_write_queue_constructor(sqlez::thread_safe_connection::locking_queue());
    builder.build().await.unwrap()
}
174
/// Implements a basic DB wrapper for a given domain, plus a lazily opened global connection.
///
/// Two forms are accepted:
///
/// * `pub static ref DB: MyDb<()> = &[/* migrations */];`: the connection is typed over
///   `MyDb` alone.
/// * `pub static ref DB: MyDb<Dep1, Dep2> = &[/* migrations */];`: the connection is typed
///   over the tuple `(Dep1, Dep2, MyDb)`, so the dependency domains are migrated as well.
///
/// Either form generates:
///
/// * a `pub struct MyDb` wrapping a `ThreadSafeConnection` and dereferencing to it;
/// * a `Domain` impl named after the type and returning the given migrations;
/// * the `DB` static, backed by `open_test_db` under `test` / `test-support`, and by
///   `open_db(&DB_DIR, &RELEASE_CHANNEL)` otherwise (both driven with `smol::block_on`).
#[macro_export]
macro_rules! define_connection {
    (pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => {
        $crate::define_connection!(
            @define $id,
            $t,
            $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>,
            $migrations
        );
    };
    (pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => {
        $crate::define_connection!(
            @define $id,
            $t,
            $crate::sqlez::thread_safe_connection::ThreadSafeConnection<($($d),+, $t)>,
            $migrations
        );
    };
    // Internal arm shared by both public forms; `$connection` is the full connection type,
    // so the generated items only have to be written (and kept correct) once.
    (@define $id:ident, $t:ident, $connection:ty, $migrations:expr) => {
        pub struct $t($connection);

        impl ::std::ops::Deref for $t {
            type Target = $connection;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl $crate::sqlez::domain::Domain for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn migrations() -> &'static [&'static str] {
                $migrations
            }
        }

        #[cfg(any(test, feature = "test-support"))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
        }

        #[cfg(not(any(test, feature = "test-support")))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
        }
    };
}
241
242pub fn write_and_log<F>(cx: &mut AppContext, db_write: impl FnOnce() -> F + Send + 'static)
243where
244 F: Future<Output = anyhow::Result<()>> + Send,
245{
246 cx.background()
247 .spawn(async move { db_write().await.log_err() })
248 .detach()
249}
250
#[cfg(test)]
mod tests {
    use std::{fs, thread};

    use sqlez::{connection::Connection, domain::Domain};
    use sqlez_macros::sql;
    use tempdir::TempDir;

    use crate::{open_db, DB_FILE_NAME};

    /// A domain whose migrations can never succeed must make `open_db` panic.
    ///
    /// The migrations fail on every attempt `open_db` makes: the existing file, the freshly
    /// recreated file, and finally the in-memory fallback. The fallback's failure is the
    /// `expect` that panics.
    #[gpui::test]
    #[should_panic]
    async fn test_bad_migration_panics() {
        enum BadDB {}

        impl Domain for BadDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[
                    sql!(CREATE TABLE test(value);),
                    // failure because test already exists
                    sql!(CREATE TABLE test(value);),
                ]
            }
        }

        let tempdir = TempDir::new("DbTests").unwrap();
        let _bad_db = open_db::<BadDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
    }

    /// Test that DB exists but corrupted (causing recreate).
    ///
    /// The existing database is expected to be moved to a backup directory, which keeps its
    /// old contents, and a fresh database is created in its place.
    #[gpui::test]
    async fn test_db_corruption() {
        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test(value);)]
            }
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            fn name() -> &'static str {
                "db_tests" //Notice same name
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test2(value);)] //But different migration
            }
        }

        let tempdir = TempDir::new("DbTests").unwrap();
        {
            // Create the on-disk database with CorruptedDB's migrations. The scope drops the
            // connection before the same directory is reopened as GoodDB.
            let corrupt_db =
                open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
            assert!(corrupt_db.persistent());
        }

        // Same domain name with different migrations: presumably rejected by the migrator,
        // which should send open_db down its backup-and-recreate path.
        let good_db = open_db::<GoodDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
        assert!(
            good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                .unwrap()
                .is_none()
        );

        // The backup is the only directory entry not starting with "0" (the live database
        // lives in 0-{channel}). It must still contain the original `test` table.
        let mut corrupted_backup_dir = fs::read_dir(tempdir.path())
            .unwrap()
            .find(|entry| {
                !entry
                    .as_ref()
                    .unwrap()
                    .file_name()
                    .to_str()
                    .unwrap()
                    .starts_with("0")
            })
            .unwrap()
            .unwrap()
            .path();
        corrupted_backup_dir.push(DB_FILE_NAME);

        let backup = Connection::open_file(&corrupted_backup_dir.to_string_lossy());
        assert!(backup.select_row::<usize>("SELECT * FROM test").unwrap()()
            .unwrap()
            .is_none());
    }

    /// Test that many threads opening the same corrupted DB at once all end up with a working
    /// connection.
    ///
    /// This exercises open_db's locked recovery path (`DB_FILE_OPERATIONS`) under
    /// contention.
    #[gpui::test]
    async fn test_simultaneous_db_corruption() {
        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test(value);)]
            }
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            fn name() -> &'static str {
                "db_tests" //Notice same name
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test2(value);)] //But different migration
            }
        }

        let tempdir = TempDir::new("DbTests").unwrap();
        {
            // Setup the bad database
            let corrupt_db =
                open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
            assert!(corrupt_db.persistent());
        }

        // Try to connect to it a bunch of times at once. Each thread drives its own
        // open_db future with smol::block_on.
        let mut guards = vec![];
        for _ in 0..10 {
            let tmp_path = tempdir.path().to_path_buf();
            let guard = thread::spawn(move || {
                let good_db = smol::block_on(open_db::<GoodDB>(
                    tmp_path.as_path(),
                    &util::channel::ReleaseChannel::Dev,
                ));
                assert!(
                    good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                        .unwrap()
                        .is_none()
                );
            });

            guards.push(guard);
        }

        // A failed assertion inside any thread surfaces here as a join error.
        for guard in guards.into_iter() {
            assert!(guard.join().is_ok());
        }
    }
}