1pub mod kvp;
2pub mod query;
3
4// Re-export
5pub use anyhow;
6use anyhow::Context;
7use gpui::AppContext;
8pub use indoc::indoc;
9pub use lazy_static;
10use parking_lot::{Mutex, RwLock};
11pub use smol;
12pub use sqlez;
13pub use sqlez_macros;
14pub use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
15pub use util::paths::DB_DIR;
16
17use sqlez::domain::Migrator;
18use sqlez::thread_safe_connection::ThreadSafeConnection;
19use sqlez_macros::sql;
20use std::fs::create_dir_all;
21use std::future::Future;
22use std::path::{Path, PathBuf};
23use std::sync::atomic::{AtomicBool, Ordering};
24use std::time::{SystemTime, UNIX_EPOCH};
25use util::channel::ReleaseChannel;
26use util::{async_iife, ResultExt};
27
28const CONNECTION_INITIALIZE_QUERY: &'static str = sql!(
29 PRAGMA foreign_keys=TRUE;
30);
31
32const DB_INITIALIZE_QUERY: &'static str = sql!(
33 PRAGMA journal_mode=WAL;
34 PRAGMA busy_timeout=1;
35 PRAGMA case_sensitive_like=TRUE;
36 PRAGMA synchronous=NORMAL;
37);
38
/// Name of the shared in-memory database opened when no file-backed database can be used.
const FALLBACK_DB_NAME: &str = "FALLBACK_MEMORY_DB";

/// File name of the SQLite database inside each `{prefix}-{channel}` database directory.
const DB_FILE_NAME: &str = "db.sqlite";
42
43lazy_static::lazy_static! {
44 pub static ref ZED_STATELESS: bool = std::env::var("ZED_STATELESS").map_or(false, |v| !v.is_empty());
45 pub static ref BACKUP_DB_PATH: RwLock<Option<PathBuf>> = RwLock::new(None);
46 pub static ref ALL_FILE_DB_FAILED: AtomicBool = AtomicBool::new(false);
47}
48static DB_FILE_OPERATIONS: Mutex<()> = Mutex::new(());
49
50/// Open or create a database at the given directory path.
51/// This will retry a couple times if there are failures. If opening fails once, the db directory
52/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
53/// In either case, static variables are set so that the user can be notified.
54pub async fn open_db<M: Migrator + 'static>(
55 db_dir: &Path,
56 release_channel: &ReleaseChannel,
57) -> ThreadSafeConnection<M> {
58 if *ZED_STATELESS {
59 return open_fallback_db().await;
60 }
61
62 let release_channel_name = release_channel.dev_name();
63 let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel_name)));
64
65 let connection = async_iife!({
66 // Note: This still has a race condition where 1 set of migrations succeeds
67 // (e.g. (Workspace, Editor)) and another fails (e.g. (Workspace, Terminal))
68 // This will cause the first connection to have the database taken out
69 // from under it. This *should* be fine though. The second dabatase failure will
70 // cause errors in the log and so should be observed by developers while writing
71 // soon-to-be good migrations. If user databases are corrupted, we toss them out
72 // and try again from a blank. As long as running all migrations from start to end
73 // on a blank database is ok, this race condition will never be triggered.
74 //
75 // Basically: Don't ever push invalid migrations to stable or everyone will have
76 // a bad time.
77
78 // If no db folder, create one at 0-{channel}
79 create_dir_all(&main_db_dir).context("Could not create db directory")?;
80 let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
81
82 // Optimistically open databases in parallel
83 if !DB_FILE_OPERATIONS.is_locked() {
84 // Try building a connection
85 if let Some(connection) = open_main_db(&db_path).await {
86 return Ok(connection)
87 };
88 }
89
90 // Take a lock in the failure case so that we move the db once per process instead
91 // of potentially multiple times from different threads. This shouldn't happen in the
92 // normal path
93 let _lock = DB_FILE_OPERATIONS.lock();
94 if let Some(connection) = open_main_db(&db_path).await {
95 return Ok(connection)
96 };
97
98 let backup_timestamp = SystemTime::now()
99 .duration_since(UNIX_EPOCH)
100 .expect("System clock is set before the unix timestamp, Zed does not support this region of spacetime")
101 .as_millis();
102
103 // If failed, move 0-{channel} to {current unix timestamp}-{channel}
104 let backup_db_dir = db_dir.join(Path::new(&format!(
105 "{}-{}",
106 backup_timestamp,
107 release_channel_name,
108 )));
109
110 std::fs::rename(&main_db_dir, &backup_db_dir)
111 .context("Failed clean up corrupted database, panicking.")?;
112
113 // Set a static ref with the failed timestamp and error so we can notify the user
114 {
115 let mut guard = BACKUP_DB_PATH.write();
116 *guard = Some(backup_db_dir);
117 }
118
119 // Create a new 0-{channel}
120 create_dir_all(&main_db_dir).context("Should be able to create the database directory")?;
121 let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
122
123 // Try again
124 open_main_db(&db_path).await.context("Could not newly created db")
125 }).await.log_err();
126
127 if let Some(connection) = connection {
128 return connection;
129 }
130
131 // Set another static ref so that we can escalate the notification
132 ALL_FILE_DB_FAILED.store(true, Ordering::Release);
133
134 // If still failed, create an in memory db with a known name
135 open_fallback_db().await
136}
137
138async fn open_main_db<M: Migrator>(db_path: &PathBuf) -> Option<ThreadSafeConnection<M>> {
139 log::info!("Opening main db");
140 ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
141 .with_db_initialization_query(DB_INITIALIZE_QUERY)
142 .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
143 .build()
144 .await
145 .log_err()
146}
147
148async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> {
149 log::info!("Opening fallback db");
150 ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false)
151 .with_db_initialization_query(DB_INITIALIZE_QUERY)
152 .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
153 .build()
154 .await
155 .expect(
156 "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
157 )
158}
159
/// Open a database named `db_name` for tests (the `false` flag presumably makes it
/// non-persistent, matching `open_fallback_db` — confirm in `sqlez`).
///
/// Queued writes go through `locking_queue`, which serializes them via a mutex and runs
/// them synchronously.
///
/// # Panics
///
/// Panics if the connection cannot be built.
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
    let built = ThreadSafeConnection::<M>::builder(db_name, false)
        .with_db_initialization_query(DB_INITIALIZE_QUERY)
        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
        .with_write_queue_constructor(sqlez::thread_safe_connection::locking_queue())
        .build()
        .await;
    built.unwrap()
}
173
/// Implements a basic DB wrapper for a given domain.
///
/// Two forms are accepted:
///
/// - `pub static ref DB: Name<()> = migrations;` wraps a `ThreadSafeConnection<Name>`.
/// - `pub static ref DB: Name<Dep1, Dep2, ...> = migrations;` wraps a
///   `ThreadSafeConnection<(Dep1, Dep2, ..., Name)>`, so the listed domains are part of the
///   connection's migrator alongside `Name`.
///
/// Either form defines `pub struct Name` (which derefs to the wrapped connection), implements
/// `Domain` for it with the given `migrations` and its own name, and declares a lazily
/// initialized `pub static ref DB`. With `test` or the `test-support` feature the static opens
/// a test database named after the static; otherwise it opens the real database in `DB_DIR`
/// for `RELEASE_CHANNEL`.
#[macro_export]
macro_rules! define_connection {
    (pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => {
        $crate::define_connection!(@define $id, $t, $t, $migrations);
    };
    (pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => {
        $crate::define_connection!(@define $id, $t, ($($d),+, $t), $migrations);
    };
    // Internal arm shared by both public forms; `$migrator` is the type parameter of the
    // wrapped `ThreadSafeConnection`.
    (@define $id:ident, $t:ident, $migrator:ty, $migrations:expr) => {
        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$migrator>);

        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$migrator>;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl $crate::sqlez::domain::Domain for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn migrations() -> &'static [&'static str] {
                $migrations
            }
        }

        #[cfg(any(test, feature = "test-support"))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_test_db(stringify!($id))));
        }

        #[cfg(not(any(test, feature = "test-support")))]
        $crate::lazy_static::lazy_static! {
            pub static ref $id: $t = $t($crate::smol::block_on($crate::open_db(&$crate::DB_DIR, &$crate::RELEASE_CHANNEL)));
        }
    };
}
240
241pub fn write_and_log<F>(cx: &mut AppContext, db_write: impl FnOnce() -> F + Send + 'static)
242where
243 F: Future<Output = anyhow::Result<()>> + Send,
244{
245 cx.background()
246 .spawn(async move { db_write().await.log_err() })
247 .detach()
248}
249
#[cfg(test)]
mod tests {
    use std::{fs, thread};

    use sqlez::{connection::Connection, domain::Domain};
    use sqlez_macros::sql;
    use tempdir::TempDir;

    use crate::{open_db, DB_FILE_NAME};

    // Test that a migration list which cannot apply even to a blank database panics: the main
    // db, the recreated db, and finally the in-memory fallback all fail, and the fallback's
    // failure is fatal.
    #[gpui::test]
    #[should_panic]
    async fn test_bad_migration_panics() {
        enum BadDB {}

        impl Domain for BadDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[
                    sql!(CREATE TABLE test(value);),
                    // failure because test already exists
                    sql!(CREATE TABLE test(value);),
                ]
            }
        }

        let tempdir = TempDir::new("DbTests").unwrap();
        let _bad_db = open_db::<BadDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
    }

    /// Test that DB exists but corrupted (causing recreate).
    ///
    /// `CorruptedDB` and `GoodDB` share a domain name but have different migrations. Opening
    /// `GoodDB` over `CorruptedDB`'s database must move the old directory aside (its `test`
    /// table still readable in the backup) and create a fresh database containing `test2`.
    #[gpui::test]
    async fn test_db_corruption() {
        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test(value);)]
            }
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            fn name() -> &'static str {
                "db_tests" // Notice: same name
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test2(value);)] // But different migration
            }
        }

        let tempdir = TempDir::new("DbTests").unwrap();
        // Scoped so the first connection is dropped (presumably closing its handles) before
        // the same directory is reopened with the mismatched domain.
        {
            let corrupt_db =
                open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
            assert!(corrupt_db.persistent());
        }

        let good_db = open_db::<GoodDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
        assert!(
            good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                .unwrap()
                .is_none()
        );

        // The backup directory is the entry not named `0-{channel}`; backup names start with
        // a unix-millis timestamp.
        let mut corrupted_backup_dir = fs::read_dir(tempdir.path())
            .unwrap()
            .find(|entry| {
                !entry
                    .as_ref()
                    .unwrap()
                    .file_name()
                    .to_str()
                    .unwrap()
                    .starts_with("0")
            })
            .unwrap()
            .unwrap()
            .path();
        corrupted_backup_dir.push(DB_FILE_NAME);

        // The original database must survive intact in the backup.
        let backup = Connection::open_file(&corrupted_backup_dir.to_string_lossy());
        assert!(backup.select_row::<usize>("SELECT * FROM test").unwrap()()
            .unwrap()
            .is_none());
    }

    /// Test that many threads opening the same corrupted DB at once all end up with a working
    /// connection to the recreated database (exercises the `DB_FILE_OPERATIONS` locking in
    /// `open_db`).
    #[gpui::test]
    async fn test_simultaneous_db_corruption() {
        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test(value);)]
            }
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            fn name() -> &'static str {
                "db_tests" // Notice: same name
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test2(value);)] // But different migration
            }
        }

        let tempdir = TempDir::new("DbTests").unwrap();
        {
            // Setup the bad database
            let corrupt_db =
                open_db::<CorruptedDB>(tempdir.path(), &util::channel::ReleaseChannel::Dev).await;
            assert!(corrupt_db.persistent());
        }

        // Try to connect to it a bunch of times at once
        let mut guards = vec![];
        for _ in 0..10 {
            let tmp_path = tempdir.path().to_path_buf();
            let guard = thread::spawn(move || {
                let good_db = smol::block_on(open_db::<GoodDB>(
                    tmp_path.as_path(),
                    &util::channel::ReleaseChannel::Dev,
                ));
                assert!(
                    good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                        .unwrap()
                        .is_none()
                );
            });

            guards.push(guard);
        }

        // Every thread must have obtained a usable connection without panicking.
        for guard in guards.into_iter() {
            assert!(guard.join().is_ok());
        }
    }
}