1pub mod kvp;
2pub mod query;
3
4// Re-export
5pub use anyhow;
6use anyhow::Context as _;
7use gpui::{App, AppContext};
8pub use indoc::indoc;
9pub use paths::database_dir;
10pub use smol;
11pub use sqlez;
12pub use sqlez_macros;
13
14pub use release_channel::RELEASE_CHANNEL;
15use sqlez::domain::Migrator;
16use sqlez::thread_safe_connection::ThreadSafeConnection;
17use sqlez_macros::sql;
18use std::future::Future;
19use std::path::Path;
20use std::sync::atomic::AtomicBool;
21use std::sync::{LazyLock, atomic::Ordering};
22use util::{ResultExt, maybe};
23use zed_env_vars::ZED_STATELESS;
24
25const CONNECTION_INITIALIZE_QUERY: &str = sql!(
26 PRAGMA foreign_keys=TRUE;
27);
28
29const DB_INITIALIZE_QUERY: &str = sql!(
30 PRAGMA journal_mode=WAL;
31 PRAGMA busy_timeout=1;
32 PRAGMA case_sensitive_like=TRUE;
33 PRAGMA synchronous=NORMAL;
34);
35
36const FALLBACK_DB_NAME: &str = "FALLBACK_MEMORY_DB";
37
38const DB_FILE_NAME: &str = "db.sqlite";
39
40pub static ALL_FILE_DB_FAILED: LazyLock<AtomicBool> = LazyLock::new(|| AtomicBool::new(false));
41
42/// Open or create a database at the given directory path.
43/// This will retry a couple times if there are failures. If opening fails once, the db directory
44/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
45/// In either case, static variables are set so that the user can be notified.
46pub async fn open_db<M: Migrator + 'static>(db_dir: &Path, scope: &str) -> ThreadSafeConnection {
47 if *ZED_STATELESS {
48 return open_fallback_db::<M>().await;
49 }
50
51 let main_db_dir = db_dir.join(format!("0-{}", scope));
52
53 let connection = maybe!(async {
54 smol::fs::create_dir_all(&main_db_dir)
55 .await
56 .context("Could not create db directory")
57 .log_err()?;
58 let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
59 open_main_db::<M>(&db_path).await
60 })
61 .await;
62
63 if let Some(connection) = connection {
64 return connection;
65 }
66
67 // Set another static ref so that we can escalate the notification
68 ALL_FILE_DB_FAILED.store(true, Ordering::Release);
69
70 // If still failed, create an in memory db with a known name
71 open_fallback_db::<M>().await
72}
73
74async fn open_main_db<M: Migrator>(db_path: &Path) -> Option<ThreadSafeConnection> {
75 log::trace!("Opening database {}", db_path.display());
76 ThreadSafeConnection::builder::<M>(db_path.to_string_lossy().as_ref(), true)
77 .with_db_initialization_query(DB_INITIALIZE_QUERY)
78 .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
79 .build()
80 .await
81 .log_err()
82}
83
84async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection {
85 log::warn!("Opening fallback in-memory database");
86 ThreadSafeConnection::builder::<M>(FALLBACK_DB_NAME, false)
87 .with_db_initialization_query(DB_INITIALIZE_QUERY)
88 .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
89 .build()
90 .await
91 .expect(
92 "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
93 )
94}
95
/// Opens a non-persistent database named `db_name` for tests, with `M` as the migrator
/// and this crate's initialization queries applied.
///
/// # Panics
///
/// Panics if the database cannot be built.
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection {
    let builder = ThreadSafeConnection::builder::<M>(db_name, false)
        .with_db_initialization_query(DB_INITIALIZE_QUERY)
        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
        // Serialize queued writes via a mutex and run them synchronously.
        .with_write_queue_constructor(sqlez::thread_safe_connection::locking_queue());
    let built = builder.build().await;
    built.unwrap()
}
109
/// Implements a basic DB wrapper for a given domain.
///
/// `$t` must be a tuple struct whose field `.0` is a `ThreadSafeConnection`. The macro
/// generates:
/// - `impl Deref for $t` with `Target = ThreadSafeConnection`, returning `&self.0`;
/// - in test builds (`test` or the `test-support` feature), `$t::open_test_db(name)`.
///   Note that it migrates only `$t`, not the listed dependencies;
/// - `pub static $id: LazyLock<$t>`, opened on first access:
///   - in test builds, it blocks on `open_test_db` for a database named after `$id`,
///     migrating the dependencies followed by `$t`;
///   - otherwise, it blocks on `open_db` under `database_dir()`.
///
/// Arguments:
/// - static variable name for connection
/// - type of connection wrapper
/// - dependencies, whose migrations should be run prior to this domain's migrations
/// - optionally, the identifier `global`. This stores the database under the `global`
///   scope instead of `RELEASE_CHANNEL.dev_name()`. Any other identifier in this
///   position is accepted but has no effect.
///
/// Example: `static_connection!(DB, MyDb, [OtherDomain], global);`
#[macro_export]
macro_rules! static_connection {
    ($id:ident, $t:ident, [ $($d:ty),* ] $(, $global:ident)?) => {
        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl $t {
            // Migrates only `$t` (no dependencies), unlike the static below.
            #[cfg(any(test, feature = "test-support"))]
            pub async fn open_test_db(name: &'static str) -> Self {
                $t($crate::open_test_db::<$t>(name).await)
            }
        }

        #[cfg(any(test, feature = "test-support"))]
        pub static $id: std::sync::LazyLock<$t> = std::sync::LazyLock::new(|| {
            // With no dependencies the migrator type is `($t)`, a parenthesized type
            // that would otherwise trigger `unused_parens`.
            #[allow(unused_parens)]
            $t($crate::smol::block_on($crate::open_test_db::<($($d,)* $t)>(stringify!($id))))
        });

        #[cfg(not(any(test, feature = "test-support")))]
        pub static $id: std::sync::LazyLock<$t> = std::sync::LazyLock::new(|| {
            let db_dir = $crate::database_dir();
            // `false || ...` keeps this a valid condition when the optional argument is
            // absent. Only the literal identifier `global` selects the global scope.
            let scope = if false $(|| stringify!($global) == "global")? {
                "global"
            } else {
                $crate::RELEASE_CHANNEL.dev_name()
            };
            // See above: `($t)` when there are no dependencies.
            #[allow(unused_parens)]
            $t($crate::smol::block_on($crate::open_db::<($($d,)* $t)>(db_dir, scope)))
        });
    }
}
153
154pub fn write_and_log<F>(cx: &App, db_write: impl FnOnce() -> F + Send + 'static)
155where
156 F: Future<Output = anyhow::Result<()>> + Send,
157{
158 cx.background_spawn(async move { db_write().await.log_err() })
159 .detach()
160}
161
// Tests for `open_db`'s handling of failing migrations and of reopening an existing
// on-disk database with a different migration history.
#[cfg(test)]
mod tests {
    use std::thread;

    use sqlez::domain::Domain;
    use sqlez_macros::sql;

    use crate::open_db;

    // A migration that fails on every attempt must panic. The on-disk open fails and is
    // logged. `open_db` then falls back to the in-memory database, which runs the same
    // migrations, fails again, and hits the `expect` in `open_fallback_db`.
    #[gpui::test]
    #[should_panic]
    async fn test_bad_migration_panics() {
        enum BadDB {}

        impl Domain for BadDB {
            const NAME: &str = "db_tests";
            const MIGRATIONS: &[&str] = &[
                sql!(CREATE TABLE test(value);),
                // failure because test already exists
                sql!(CREATE TABLE test(value);),
            ];
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        let _bad_db = open_db::<BadDB>(
            tempdir.path(),
            release_channel::ReleaseChannel::Dev.dev_name(),
        )
        .await;
    }

    /// Creates an on-disk database with one migration history. It then reopens the same
    /// location with a domain of the same `NAME` but a different migration, and checks
    /// that the resulting connection has an empty `test2` table.
    ///
    /// NOTE(review): `open_db` no longer moves or recreates the database directory.
    /// Presumably the mismatched migration makes `open_main_db` fail, and this assertion
    /// is satisfied by the in-memory fallback — confirm against sqlez's migrator.
    #[gpui::test]
    async fn test_db_corruption(cx: &mut gpui::TestAppContext) {
        // Presumably required because opening the database blocks the test thread.
        cx.executor().allow_parking();

        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            const NAME: &str = "db_tests";
            const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test(value);)];
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            const NAME: &str = "db_tests"; //Notice same name
            const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test2(value);)];
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        {
            // Drop the first connection before reopening.
            let corrupt_db = open_db::<CorruptedDB>(
                tempdir.path(),
                release_channel::ReleaseChannel::Dev.dev_name(),
            )
            .await;
            assert!(corrupt_db.persistent());
        }

        let good_db = open_db::<GoodDB>(
            tempdir.path(),
            release_channel::ReleaseChannel::Dev.dev_name(),
        )
        .await;
        assert!(
            good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                .unwrap()
                .is_none()
        );
    }

    /// Concurrent variant of the corruption scenario. After creating the database with
    /// the `CorruptedDB` migration history, ten threads each reopen it as `GoodDB` at
    /// the same time. Each asserts that its connection has an empty `test2` table, and
    /// the main thread checks that no thread panicked.
    ///
    /// NOTE(review): as with the single-threaded case, this likely exercises the
    /// in-memory fallback rather than any on-disk recreation — confirm.
    #[gpui::test(iterations = 30)]
    async fn test_simultaneous_db_corruption(cx: &mut gpui::TestAppContext) {
        // Presumably required because opening the database blocks the test thread.
        cx.executor().allow_parking();

        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            const NAME: &str = "db_tests";

            const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test(value);)];
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            const NAME: &str = "db_tests"; //Notice same name
            const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test2(value);)]; // But different migration
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        {
            // Setup the bad database
            let corrupt_db = open_db::<CorruptedDB>(
                tempdir.path(),
                release_channel::ReleaseChannel::Dev.dev_name(),
            )
            .await;
            assert!(corrupt_db.persistent());
        }

        // Try to connect to it a bunch of times at once
        let mut guards = vec![];
        for _ in 0..10 {
            let tmp_path = tempdir.path().to_path_buf();
            let guard = thread::spawn(move || {
                let good_db = smol::block_on(open_db::<GoodDB>(
                    tmp_path.as_path(),
                    release_channel::ReleaseChannel::Dev.dev_name(),
                ));
                assert!(
                    good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                        .unwrap()
                        .is_none()
                );
            });

            guards.push(guard);
        }

        // A panic (failed assertion) in any worker surfaces here as a join error.
        for guard in guards.into_iter() {
            assert!(guard.join().is_ok());
        }
    }
}