1pub mod kvp;
2pub mod query;
3
4// Re-export
5pub use anyhow;
6use anyhow::Context;
7use gpui::AppContext;
8pub use indoc::indoc;
9pub use paths::database_dir;
10pub use smol;
11pub use sqlez;
12pub use sqlez_macros;
13
14use release_channel::ReleaseChannel;
15pub use release_channel::RELEASE_CHANNEL;
16use sqlez::domain::Migrator;
17use sqlez::thread_safe_connection::ThreadSafeConnection;
18use sqlez_macros::sql;
19use std::env;
20use std::future::Future;
21use std::path::{Path, PathBuf};
22use std::sync::atomic::{AtomicBool, Ordering};
23use std::sync::LazyLock;
24use util::{maybe, ResultExt};
25
/// Query passed to `with_connection_initialize_query` for both the on-disk and fallback databases.
/// Per the SQLite PRAGMA documentation, foreign-key enforcement is a per-connection setting,
/// which is presumably why it is applied as a connection-level initializer.
const CONNECTION_INITIALIZE_QUERY: &str = sql!(
    PRAGMA foreign_keys=TRUE;
);

/// Query passed to `with_db_initialization_query` for both the on-disk and fallback databases.
/// Per the SQLite PRAGMA documentation:
/// - `journal_mode=WAL`: write-ahead logging (persisted in the database file).
/// - `busy_timeout=1`: wait at most 1 millisecond on a locked database before giving up.
/// - `case_sensitive_like=TRUE`: make `LIKE` comparisons case-sensitive.
/// - `synchronous=NORMAL`: fewer syncs than `FULL`.
///
/// NOTE(review): `busy_timeout`, `case_sensitive_like` and `synchronous` are per-connection in
/// SQLite. If sqlez runs this query on only one connection (as the name suggests), other
/// connections keep SQLite's defaults for them — confirm against `ThreadSafeConnection`.
const DB_INITIALIZE_QUERY: &str = sql!(
    PRAGMA journal_mode=WAL;
    PRAGMA busy_timeout=1;
    PRAGMA case_sensitive_like=TRUE;
    PRAGMA synchronous=NORMAL;
);
36
/// File name of the on-disk SQLite database inside the per-release-channel directory.
const DB_FILE_NAME: &str = "db.sqlite";

/// Name under which the non-persistent fallback database is opened when no on-disk
/// database is available (or when running stateless).
const FALLBACK_DB_NAME: &str = "FALLBACK_MEMORY_DB";
40
/// `true` when the `ZED_STATELESS` environment variable is set to a non-empty (valid Unicode)
/// value. When set, `open_db` skips the on-disk database and uses the in-memory fallback.
pub static ZED_STATELESS: LazyLock<bool> =
    LazyLock::new(|| env::var("ZED_STATELESS").is_ok_and(|v| !v.is_empty()));
43
/// Starts out `false`; `open_db` sets it to `true` (with `Ordering::Release`) after giving up on
/// the on-disk database and falling back to the in-memory one, so the failure notification can
/// be escalated.
pub static ALL_FILE_DB_FAILED: LazyLock<AtomicBool> = LazyLock::new(AtomicBool::default);
45
46/// Open or create a database at the given directory path.
47/// This will retry a couple times if there are failures. If opening fails once, the db directory
48/// is moved to a backup folder and a new one is created. If that fails, a shared in memory db is created.
49/// In either case, static variables are set so that the user can be notified.
50pub async fn open_db<M: Migrator + 'static>(
51 db_dir: &Path,
52 release_channel: &ReleaseChannel,
53) -> ThreadSafeConnection<M> {
54 if *ZED_STATELESS {
55 return open_fallback_db().await;
56 }
57
58 let release_channel_name = release_channel.dev_name();
59 let main_db_dir = db_dir.join(Path::new(&format!("0-{}", release_channel_name)));
60
61 let connection = maybe!(async {
62 smol::fs::create_dir_all(&main_db_dir)
63 .await
64 .context("Could not create db directory")
65 .log_err()?;
66 let db_path = main_db_dir.join(Path::new(DB_FILE_NAME));
67 open_main_db(&db_path).await
68 })
69 .await;
70
71 if let Some(connection) = connection {
72 return connection;
73 }
74
75 // Set another static ref so that we can escalate the notification
76 ALL_FILE_DB_FAILED.store(true, Ordering::Release);
77
78 // If still failed, create an in memory db with a known name
79 open_fallback_db().await
80}
81
82async fn open_main_db<M: Migrator>(db_path: &PathBuf) -> Option<ThreadSafeConnection<M>> {
83 log::info!("Opening main db");
84 ThreadSafeConnection::<M>::builder(db_path.to_string_lossy().as_ref(), true)
85 .with_db_initialization_query(DB_INITIALIZE_QUERY)
86 .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
87 .build()
88 .await
89 .log_err()
90}
91
92async fn open_fallback_db<M: Migrator>() -> ThreadSafeConnection<M> {
93 log::info!("Opening fallback db");
94 ThreadSafeConnection::<M>::builder(FALLBACK_DB_NAME, false)
95 .with_db_initialization_query(DB_INITIALIZE_QUERY)
96 .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
97 .build()
98 .await
99 .expect(
100 "Fallback in memory database failed. Likely initialization queries or migrations have fundamental errors",
101 )
102}
103
/// Opens a test database named `db_name`, built with the `false` (non-persistent) flag like the
/// fallback database, using the shared initialization queries.
///
/// # Panics
///
/// Panics if the connection cannot be built.
#[cfg(any(test, feature = "test-support"))]
pub async fn open_test_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<M> {
    use sqlez::thread_safe_connection::locking_queue;

    let builder = ThreadSafeConnection::<M>::builder(db_name, false)
        .with_db_initialization_query(DB_INITIALIZE_QUERY)
        .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
        // Serialize queued writes via a mutex and run them synchronously
        .with_write_queue_constructor(locking_queue());
    builder.build().await.unwrap()
}
117
/// Implements a basic DB wrapper for a given domain.
///
/// `define_connection!(pub static ref DB: MyDb<()> = MIGRATIONS;)` expands to:
/// - `pub struct MyDb`, a newtype over `ThreadSafeConnection` that derefs to the connection;
/// - an impl of `sqlez::domain::Domain` whose name is `"MyDb"` and whose migrations are
///   `MIGRATIONS` (a `&'static [&'static str]`);
/// - `pub static DB`, a `LazyLock` that opens the connection on first access by blocking on
///   `smol::block_on`: via `open_test_db` (named after the static) under `test` or the
///   `test-support` feature, otherwise via `open_db` in `database_dir()` for `RELEASE_CHANNEL`.
///
/// Writing `MyDb<A, B>` instead of `MyDb<()>` makes the connection's migrator the tuple
/// `(A, B, MyDb)`, presumably so the listed domains are migrated too — confirm ordering in sqlez.
///
/// `LazyLock` is always referenced by absolute path: a `use` item emitted by this macro would
/// land in the caller's module and clash with any `LazyLock` import the caller already has.
#[macro_export]
macro_rules! define_connection {
    (pub static ref $id:ident: $t:ident<()> = $migrations:expr;) => {
        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>);

        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<$t>;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl $crate::sqlez::domain::Domain for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn migrations() -> &'static [&'static str] {
                $migrations
            }
        }

        #[cfg(any(test, feature = "test-support"))]
        pub static $id: ::std::sync::LazyLock<$t> = ::std::sync::LazyLock::new(|| {
            $t($crate::smol::block_on($crate::open_test_db(stringify!($id))))
        });

        #[cfg(not(any(test, feature = "test-support")))]
        pub static $id: ::std::sync::LazyLock<$t> = ::std::sync::LazyLock::new(|| {
            $t($crate::smol::block_on($crate::open_db(
                $crate::database_dir(),
                &$crate::RELEASE_CHANNEL,
            )))
        });
    };
    (pub static ref $id:ident: $t:ident<$($d:ty),+> = $migrations:expr;) => {
        pub struct $t($crate::sqlez::thread_safe_connection::ThreadSafeConnection<( $($d),+, $t )>);

        impl ::std::ops::Deref for $t {
            type Target = $crate::sqlez::thread_safe_connection::ThreadSafeConnection<($($d),+, $t)>;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl $crate::sqlez::domain::Domain for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn migrations() -> &'static [&'static str] {
                $migrations
            }
        }

        #[cfg(any(test, feature = "test-support"))]
        pub static $id: ::std::sync::LazyLock<$t> = ::std::sync::LazyLock::new(|| {
            $t($crate::smol::block_on($crate::open_test_db(stringify!($id))))
        });

        #[cfg(not(any(test, feature = "test-support")))]
        pub static $id: ::std::sync::LazyLock<$t> = ::std::sync::LazyLock::new(|| {
            $t($crate::smol::block_on($crate::open_db(
                $crate::database_dir(),
                &$crate::RELEASE_CHANNEL,
            )))
        });
    };
}
185
186pub fn write_and_log<F>(cx: &mut AppContext, db_write: impl FnOnce() -> F + Send + 'static)
187where
188 F: Future<Output = anyhow::Result<()>> + Send,
189{
190 cx.background_executor()
191 .spawn(async move { db_write().await.log_err() })
192 .detach()
193}
194
#[cfg(test)]
mod tests {
    use std::thread;

    use sqlez::domain::Domain;
    use sqlez_macros::sql;

    use crate::open_db;

    /// A domain whose second migration re-creates an existing table must make `open_db` panic.
    /// The on-disk attempt fails and is logged; the panic is expected from the in-memory
    /// fallback's `expect` (assumes `build()` runs the domain's migrations — confirm in sqlez).
    #[gpui::test]
    #[should_panic]
    async fn test_bad_migration_panics() {
        enum BadDB {}

        impl Domain for BadDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[
                    sql!(CREATE TABLE test(value);),
                    // failure because test already exists
                    sql!(CREATE TABLE test(value);),
                ]
            }
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        let _bad_db = open_db::<BadDB>(tempdir.path(), &release_channel::ReleaseChannel::Dev).await;
    }

    /// Opens a persistent database with `CorruptedDB`'s migrations, then reopens the same
    /// directory with `GoodDB`, which has the same domain name but different migrations.
    /// The second open must still yield a connection in which table `test2` exists and is empty.
    /// NOTE(review): nothing in `open_db` recreates the on-disk db; presumably the migration
    /// mismatch makes the file open fail and the in-memory fallback is returned — confirm.
    #[gpui::test]
    async fn test_db_corruption(cx: &mut gpui::TestAppContext) {
        cx.executor().allow_parking();

        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test(value);)]
            }
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            fn name() -> &'static str {
                "db_tests" //Notice same name
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test2(value);)] //But different migration
            }
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        {
            let corrupt_db =
                open_db::<CorruptedDB>(tempdir.path(), &release_channel::ReleaseChannel::Dev).await;
            assert!(corrupt_db.persistent());
        }

        let good_db =
            open_db::<GoodDB>(tempdir.path(), &release_channel::ReleaseChannel::Dev).await;
        assert!(
            good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                .unwrap()
                .is_none()
        );
    }

    /// Same setup as `test_db_corruption`, but ten threads reopen the mismatched directory with
    /// `GoodDB` at the same time (30 iterations). Every thread must get a connection where
    /// `test2` is queryable and empty, and no thread may panic.
    #[gpui::test(iterations = 30)]
    async fn test_simultaneous_db_corruption(cx: &mut gpui::TestAppContext) {
        cx.executor().allow_parking();

        enum CorruptedDB {}

        impl Domain for CorruptedDB {
            fn name() -> &'static str {
                "db_tests"
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test(value);)]
            }
        }

        enum GoodDB {}

        impl Domain for GoodDB {
            fn name() -> &'static str {
                "db_tests" //Notice same name
            }

            fn migrations() -> &'static [&'static str] {
                &[sql!(CREATE TABLE test2(value);)] //But different migration
            }
        }

        let tempdir = tempfile::Builder::new()
            .prefix("DbTests")
            .tempdir()
            .unwrap();
        {
            // Setup the bad database
            let corrupt_db =
                open_db::<CorruptedDB>(tempdir.path(), &release_channel::ReleaseChannel::Dev).await;
            assert!(corrupt_db.persistent());
        }

        // Try to connect to it a bunch of times at once
        let mut guards = vec![];
        for _ in 0..10 {
            let tmp_path = tempdir.path().to_path_buf();
            let guard = thread::spawn(move || {
                let good_db = smol::block_on(open_db::<GoodDB>(
                    tmp_path.as_path(),
                    &release_channel::ReleaseChannel::Dev,
                ));
                assert!(
                    good_db.select_row::<usize>("SELECT * FROM test2").unwrap()()
                        .unwrap()
                        .is_none()
                );
            });

            guards.push(guard);
        }

        // A panicking thread (failed assertion) surfaces as an `Err` from `join`.
        for guard in guards.into_iter() {
            assert!(guard.join().is_ok());
        }
    }
}