1pub mod kvp;
2pub mod query;
3
4// Re-export
5pub use anyhow;
6use anyhow::Context as _;
7use gpui::{App, AppContext};
8pub use indoc::indoc;
9pub use paths::database_dir;
10pub use smol;
11pub use sqlez;
12pub use sqlez_macros;
13
14pub use release_channel::RELEASE_CHANNEL;
15use sqlez::domain::Migrator;
16use sqlez::thread_safe_connection::ThreadSafeConnection;
17use sqlez_macros::sql;
18use std::future::Future;
19use std::path::Path;
20use std::sync::{LazyLock, atomic::Ordering};
21use std::{env, sync::atomic::AtomicBool};
22use util::{ResultExt, maybe};
23
/// Pragmas handed to `with_connection_initialize_query` when building every
/// `ThreadSafeConnection` in this crate (main, fallback and test databases).
///
/// Turns on SQLite foreign-key enforcement.
/// NOTE(review): presumably the builder runs this on each new connection, since
/// `foreign_keys` is a per-connection setting in SQLite — confirm in sqlez.
const CONNECTION_INITIALIZE_QUERY: &str = sql!(
    PRAGMA foreign_keys=TRUE;
);

/// Pragmas handed to `with_db_initialization_query` when building every
/// `ThreadSafeConnection` in this crate.
///
/// - `journal_mode=WAL`: write-ahead logging.
/// - `busy_timeout=1`: SQLite measures this in milliseconds, so a locked database is retried
///   for only ~1 ms before reporting busy. NOTE(review): confirm this very short value is intended.
/// - `case_sensitive_like=TRUE`: `LIKE` comparisons become case sensitive.
/// - `synchronous=NORMAL`: fewer fsyncs than `FULL`.
///
/// NOTE(review): `busy_timeout`, `case_sensitive_like` and `synchronous` are per-connection in
/// SQLite; if sqlez runs this query only once per database rather than per connection, later
/// connections will not inherit them — verify against sqlez.
const DB_INITIALIZE_QUERY: &str = sql!(
    PRAGMA journal_mode=WAL;
    PRAGMA busy_timeout=1;
    PRAGMA case_sensitive_like=TRUE;
    PRAGMA synchronous=NORMAL;
);
34
/// Name under which the shared in-memory fallback database is opened.
const FALLBACK_DB_NAME: &str = "FALLBACK_MEMORY_DB";

/// File name of the SQLite database inside each scoped database directory.
const DB_FILE_NAME: &str = "db.sqlite";

/// `true` when the `ZED_STATELESS` environment variable is present and non-empty.
///
/// In that mode `open_db` never touches the disk and always uses the in-memory database.
pub static ZED_STATELESS: LazyLock<bool> = LazyLock::new(|| {
    matches!(env::var("ZED_STATELESS"), Ok(value) if !value.is_empty())
});

/// Set by `open_db` when the file-backed database could not be opened and the in-memory
/// fallback had to be used, so that the user can be notified.
pub static ALL_FILE_DB_FAILED: LazyLock<AtomicBool> = LazyLock::new(AtomicBool::default);
43
44/// Open or create a database at the given directory path.
45/// This will retry a couple times if there are failures. If opening fails once, the db directory
46/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
47/// In either case, static variables are set so that the user can be notified.
48pub async fn open_db<M: Migrator + 'static>(db_dir: &Path, scope: &str) -> ThreadSafeConnection {
49 if *ZED_STATELESS {
50 return open_fallback_db::<M>().await;
51 }
52
53 let main_db_dir = db_dir.join(format!("0-{}", scope));
54
55 let connection = maybe!(async {
56 smol::fs::create_dir_all(&main_db_dir)
57 .await
58 .context("Could not create db directory")
59 .log_err()?;
60 let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
61 open_main_db::<M>(&db_path).await
62 })
63 .await;
64
65 if let Some(connection) = connection {
66 return connection;
67 }
68
69 // Set another static ref so that we can escalate the notification
70 ALL_FILE_DB_FAILED.store(true, Ordering::Release);
71
72 // If still failed, create an in memory db with a known name
73 open_fallback_db::<M>().await
74}
75
76async fn open_main_db<M: Migrator>(db_path: &Path) -> Option<ThreadSafeConnection> {
77 log::trace!("Opening database {}", db_path.display());
78 ThreadSafeConnection::builder::<M>(db_path.to_string_lossy().as_ref(), true)
79 .with_db_initialization_query(DB_INITIALIZE_QUERY)
80 .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
81 .build()
82 .await
83 .log_err()
84}
85
86async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection {
87 log::warn!("Opening fallback in-memory database");
88 ThreadSafeConnection::builder::<M>(FALLBACK_DB_NAME, false)
89 .with_db_initialization_query(DB_INITIALIZE_QUERY)
90 .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
91 .build()
92 .await
93 .expect(
94 "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
95 )
96}
97
/// Open a non-persistent database named `db_name` for tests, running `M`'s migrations.
///
/// Queued writes are serialized through a mutex and executed synchronously.
///
/// # Panics
///
/// Panics if the database cannot be built.
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection {
    use sqlez::thread_safe_connection::locking_queue;

    let write_queue = locking_queue();
    let builder = ThreadSafeConnection::builder::<M>(db_name, false)
        .with_db_initialization_query(DB_INITIALIZE_QUERY)
        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
        .with_write_queue_constructor(write_queue);
    builder.build().await.unwrap()
}
111
/// Implements a basic DB wrapper for a given domain.
///
/// Arguments:
/// - static variable name for the connection
/// - type of the connection wrapper: a tuple struct whose single field is a
///   `ThreadSafeConnection`. The type itself is also used as the domain whose migrations run.
/// - dependencies, whose migrations should be run prior to this domain's migrations
/// - optionally, the identifier `global`. It opens the database under the shared `"global"`
///   scope instead of the current release channel's scope. Any other identifier in this
///   position is accepted but ignored.
///
/// The generated static opens its database lazily on first access and blocks the accessing
/// thread until it is ready. In test builds (`test` or the `test-support` feature), both the
/// static and the generated `open_test_db` constructor use a non-persistent test database. Both
/// run the dependencies' migrations before the domain's own.
#[macro_export]
macro_rules! static_connection {
    ($id:ident, $t:ident, [ $($d:ty),* ] $(, $global:ident)?) => {
        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl $t {
            /// Opens a test database named `name`. Dependency migrations run before this
            /// domain's own, matching the static connection, so tables referenced by
            /// foreign keys exist.
            #[cfg(any(test, feature = "test-support"))]
            pub async fn open_test_db(name: &'static str) -> Self {
                // With no dependencies the migrator is `($t)`, which would trip `unused_parens`.
                #[allow(unused_parens)]
                $t($crate::open_test_db::<($($d,)* $t)>(name).await)
            }
        }

        #[cfg(any(test, feature = "test-support"))]
        pub static $id: std::sync::LazyLock<$t> = std::sync::LazyLock::new(|| {
            // With no dependencies the migrator is `($t)`, which would trip `unused_parens`.
            #[allow(unused_parens)]
            $t($crate::smol::block_on($crate::open_test_db::<($($d,)* $t)>(stringify!($id))))
        });

        #[cfg(not(any(test, feature = "test-support")))]
        pub static $id: std::sync::LazyLock<$t> = std::sync::LazyLock::new(|| {
            let db_dir = $crate::database_dir();
            // Only the literal identifier `global` selects the shared scope.
            let scope = if false $(|| stringify!($global) == "global")? {
                "global"
            } else {
                $crate::RELEASE_CHANNEL.dev_name()
            };
            // With no dependencies the migrator is `($t)`, which would trip `unused_parens`.
            #[allow(unused_parens)]
            $t($crate::smol::block_on($crate::open_db::<($($d,)* $t)>(db_dir, scope)))
        });
    }
}
155
156pub fn write_and_log<F>(cx: &App, db_write: impl FnOnce() -> F + Send + 'static)
157where
158 F: Future<Output = anyhow::Result<()>> + Send,
159{
160 cx.background_spawn(async move { db_write().await.log_err() })
161 .detach()
162}
163
#[cfg(test)]
mod tests {
    use std::thread;

    use sqlez::domain::Domain;
    use sqlez_macros::sql;

    use crate::open_db;

    /// A domain whose migrations cannot run must panic.
    ///
    /// The file-backed open fails, and `open_db` falls back to the in-memory database. That
    /// fallback panics when the same migrations fail again.
    #[gpui::test]
    #[should_panic]
    async fn test_bad_migration_panics() {
        enum BadDB {}

        impl Domain for BadDB {
            const NAME: &str = "db_tests";
            const MIGRATIONS: &[&str] = &[
                sql!(CREATE TABLE test(value);),
                // Fails because `test` already exists.
                sql!(CREATE TABLE test(value);),
            ];
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        let _bad_db = open_db::<BadDB>(
            tempdir.path(),
            release_channel::ReleaseChannel::Dev.dev_name(),
        )
        .await;
    }

    /// Reopening a database whose recorded `db_tests` migrations conflict with the domain's
    /// must still produce a usable connection containing the new domain's tables. Either the
    /// file-backed database or the in-memory fallback is acceptable.
    #[gpui::test]
    async fn test_db_corruption(cx: &mut gpui::TestAppContext) {
        cx.executor().allow_parking();

        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            const NAME: &str = "db_tests";
            const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test(value);)];
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            const NAME: &str = "db_tests"; // Same name as `CorruptedDB`...
            const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test2(value);)]; // ...different migration.
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        {
            let corrupt_db = open_db::<CorruptedDB>(
                tempdir.path(),
                release_channel::ReleaseChannel::Dev.dev_name(),
            )
            .await;
            assert!(corrupt_db.persistent());
        }

        let good_db = open_db::<GoodDB>(
            tempdir.path(),
            release_channel::ReleaseChannel::Dev.dev_name(),
        )
        .await;
        assert!(
            good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                .unwrap()
                .is_none()
        );
    }

    /// Same scenario as `test_db_corruption`, except ten threads open the conflicting
    /// database at the same time. Each must end up with a usable connection that contains
    /// the new domain's tables.
    #[gpui::test(iterations = 30)]
    async fn test_simultaneous_db_corruption(cx: &mut gpui::TestAppContext) {
        cx.executor().allow_parking();

        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            const NAME: &str = "db_tests";

            const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test(value);)];
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            const NAME: &str = "db_tests"; // Same name as `CorruptedDB`...
            const MIGRATIONS: &[&str] = &[sql!(CREATE TABLE test2(value);)]; // ...different migration.
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        {
            // Set up the conflicting database.
            let corrupt_db = open_db::<CorruptedDB>(
                tempdir.path(),
                release_channel::ReleaseChannel::Dev.dev_name(),
            )
            .await;
            assert!(corrupt_db.persistent());
        }

        // Spawn every thread before joining any of them, so the opens actually overlap.
        let guards: Vec<_> = (0..10)
            .map(|_| {
                let tmp_path = tempdir.path().to_path_buf();
                thread::spawn(move || {
                    let good_db = smol::block_on(open_db::<GoodDB>(
                        tmp_path.as_path(),
                        release_channel::ReleaseChannel::Dev.dev_name(),
                    ));
                    assert!(
                        good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                            .unwrap()
                            .is_none()
                    );
                })
            })
            .collect();

        for guard in guards {
            assert!(guard.join().is_ok());
        }
    }
}