1use anyhow::Context as _;
2use collections::HashMap;
3use futures::{Future, FutureExt, channel::oneshot};
4use parking_lot::{Mutex, RwLock};
5use std::{
6 marker::PhantomData,
7 ops::Deref,
8 sync::{Arc, LazyLock},
9 thread,
10};
11use thread_local::ThreadLocal;
12
13use crate::{connection::Connection, domain::Migrator, util::UnboundedSyncSender};
14
/// Number of attempts `build` makes at running the migrations before giving up. Retrying
/// tolerates other processes migrating the same database file at the same time.
const MIGRATION_RETRIES: usize = 10;

/// A single unit of work handed to a write queue.
type QueuedWrite = Box<dyn FnOnce() + Send + 'static>;
/// Accepts queued writes for one database. Depending on the constructor it either runs them
/// inline (see `locking_queue`) or forwards them to a worker thread (see
/// `background_thread_queue`).
type WriteQueue = Box<dyn Fn(QueuedWrite) + Send + Sync + 'static>;
/// Factory that creates the write queue for a database the first time it is seen.
type WriteQueueConstructor = Box<dyn FnMut() -> WriteQueue + Send + 'static>;

/// Process-wide registry of write queues, keyed by database uri. Every thread safe connection
/// to the same uri (even ones carrying different migrations) shares the one queue stored here,
/// so writes to a given database file are serialized through a single worker.
static QUEUES: LazyLock<RwLock<HashMap<Arc<str>, WriteQueue>>> =
    LazyLock::new(|| RwLock::new(HashMap::default()));
26
/// Thread safe connection to a given database file or in memory db. This can be cloned, shared,
/// stored in a static, etc. It derefs to a per-thread synchronous connection that is read only.
/// A write capable connection may be accessed by passing a callback to the `write` function,
/// which queues the callback on the write queue registered for this database's uri.
#[derive(Clone)]
pub struct ThreadSafeConnection {
    /// Database identifier: a file path when `persistent` is true, otherwise the name of a
    /// shared in-memory database.
    uri: Arc<str>,
    /// `true` opens the file at `uri`; `false` opens a shared in-memory db named by `uri`.
    persistent: bool,
    /// Query run on every newly created per-thread connection.
    connection_initialize_query: Option<&'static str>,
    /// Lazily created per-thread connections; clones share the same set via the `Arc`.
    connections: Arc<ThreadLocal<Connection>>,
}

// NOTE(review): Send/Sync are asserted manually instead of being derived, presumably because
// `Connection` (and therefore `ThreadLocal<Connection>`) does not provide them on its own.
// The soundness argument (each thread only touches its own connection via `ThreadLocal`)
// should be confirmed and recorded here as a proper SAFETY comment.
unsafe impl Send for ThreadSafeConnection {}
unsafe impl Sync for ThreadSafeConnection {}
40
/// Builder for a [`ThreadSafeConnection`] whose database is migrated with `M` when `build`
/// is awaited.
pub struct ThreadSafeConnectionBuilder<M: Migrator + 'static = ()> {
    /// Query run once by `build` on the write queue, before the migrations.
    db_initialize_query: Option<&'static str>,
    /// Constructor for this database's write queue; `None` falls back to
    /// `background_thread_queue` when the queue is first registered.
    write_queue_constructor: Option<WriteQueueConstructor>,
    /// The connection being configured; returned by `build` on success.
    connection: ThreadSafeConnection,
    /// Ties the builder to the migrator type `M` without storing one. A `*mut M` marker also
    /// makes the builder neither `Send` nor `Sync`. NOTE(review): confirm that is intentional.
    _migrator: PhantomData<*mut M>,
}
47
48impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
49 /// Sets the query to run every time a connection is opened. This must
50 /// be infallible (EG only use pragma statements) and not cause writes.
51 /// to the db or it will panic.
52 pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
53 self.connection.connection_initialize_query = Some(initialize_query);
54 self
55 }
56
57 /// Queues an initialization query for the database file. This must be infallible
58 /// but may cause changes to the database file such as with `PRAGMA journal_mode`
59 pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
60 self.db_initialize_query = Some(initialize_query);
61 self
62 }
63
64 /// Specifies how the thread safe connection should serialize writes. If provided
65 /// the connection will call the write_queue_constructor for each database file in
66 /// this process. The constructor is responsible for setting up a background thread or
67 /// async task which handles queued writes with the provided connection.
68 pub fn with_write_queue_constructor(
69 mut self,
70 write_queue_constructor: WriteQueueConstructor,
71 ) -> Self {
72 self.write_queue_constructor = Some(write_queue_constructor);
73 self
74 }
75
76 pub async fn build(self) -> anyhow::Result<ThreadSafeConnection> {
77 self.connection
78 .initialize_queues(self.write_queue_constructor);
79
80 let db_initialize_query = self.db_initialize_query;
81
82 self.connection
83 .write(move |connection| {
84 if let Some(db_initialize_query) = db_initialize_query {
85 connection.exec(db_initialize_query).with_context(|| {
86 format!(
87 "Db initialize query failed to execute: {}",
88 db_initialize_query
89 )
90 })?()?;
91 }
92
93 // Retry failed migrations in case they were run in parallel from different
94 // processes. This gives a best attempt at migrating before bailing
95 let mut migration_result =
96 anyhow::Result::<()>::Err(anyhow::anyhow!("Migration never run"));
97
98 let foreign_keys_enabled: bool =
99 connection.select_row::<i32>("PRAGMA foreign_keys")?()
100 .unwrap_or(None)
101 .map(|enabled| enabled != 0)
102 .unwrap_or(false);
103
104 connection.exec("PRAGMA foreign_keys = OFF;")?()?;
105
106 for _ in 0..MIGRATION_RETRIES {
107 migration_result = connection
108 .with_savepoint("thread_safe_multi_migration", || M::migrate(connection));
109
110 if migration_result.is_ok() {
111 break;
112 }
113 }
114
115 if foreign_keys_enabled {
116 connection.exec("PRAGMA foreign_keys = ON;")?()?;
117 }
118 migration_result
119 })
120 .await?;
121
122 Ok(self.connection)
123 }
124}
125
126impl ThreadSafeConnection {
127 fn initialize_queues(&self, write_queue_constructor: Option<WriteQueueConstructor>) -> bool {
128 if !QUEUES.read().contains_key(&self.uri) {
129 let mut queues = QUEUES.write();
130 if !queues.contains_key(&self.uri) {
131 let mut write_queue_constructor =
132 write_queue_constructor.unwrap_or_else(background_thread_queue);
133 queues.insert(self.uri.clone(), write_queue_constructor());
134 return true;
135 }
136 }
137 false
138 }
139
140 pub fn builder<M: Migrator>(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
141 ThreadSafeConnectionBuilder::<M> {
142 db_initialize_query: None,
143 write_queue_constructor: None,
144 connection: Self {
145 uri: Arc::from(uri),
146 persistent,
147 connection_initialize_query: None,
148 connections: Default::default(),
149 },
150 _migrator: PhantomData,
151 }
152 }
153
154 /// Opens a new db connection with the initialized file path. This is internal and only
155 /// called from the deref function.
156 fn open_file(uri: &str) -> Connection {
157 Connection::open_file(uri)
158 }
159
160 /// Opens a shared memory connection using the file path as the identifier. This is internal
161 /// and only called from the deref function.
162 fn open_shared_memory(uri: &str) -> Connection {
163 Connection::open_memory(Some(uri))
164 }
165
166 pub fn write<T: 'static + Send + Sync>(
167 &self,
168 callback: impl 'static + Send + FnOnce(&Connection) -> T,
169 ) -> impl Future<Output = T> {
170 // Check and invalidate queue and maybe recreate queue
171 let queues = QUEUES.read();
172 let write_channel = queues
173 .get(&self.uri)
174 .expect("Queues are inserted when build is called. This should always succeed");
175
176 // Create a one shot channel for the result of the queued write
177 // so we can await on the result
178 let (sender, receiver) = oneshot::channel();
179
180 let thread_safe_connection = (*self).clone();
181 write_channel(Box::new(move || {
182 let connection = thread_safe_connection.deref();
183 let result = connection.with_write(|connection| callback(connection));
184 sender.send(result).ok();
185 }));
186 receiver.map(|response| response.expect("Write queue unexpectedly closed"))
187 }
188
189 pub(crate) fn create_connection(
190 persistent: bool,
191 uri: &str,
192 connection_initialize_query: Option<&'static str>,
193 ) -> Connection {
194 let mut connection = if persistent {
195 Self::open_file(uri)
196 } else {
197 Self::open_shared_memory(uri)
198 };
199
200 // Disallow writes on the connection. The only writes allowed for thread safe connections
201 // are from the background thread that can serialize them.
202 *connection.write.get_mut() = false;
203
204 if let Some(initialize_query) = connection_initialize_query {
205 connection.exec(initialize_query).unwrap_or_else(|_| {
206 panic!("Initialize query failed to execute: {}", initialize_query)
207 })()
208 .unwrap()
209 }
210
211 connection
212 }
213}
214
215impl ThreadSafeConnection {
216 /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
217 /// This allows construction to be infallible and not write to the db.
218 pub fn new(
219 uri: &str,
220 persistent: bool,
221 connection_initialize_query: Option<&'static str>,
222 write_queue_constructor: Option<WriteQueueConstructor>,
223 ) -> Self {
224 let connection = Self {
225 uri: Arc::from(uri),
226 persistent,
227 connection_initialize_query,
228 connections: Default::default(),
229 };
230
231 connection.initialize_queues(write_queue_constructor);
232 connection
233 }
234}
235
236impl Deref for ThreadSafeConnection {
237 type Target = Connection;
238
239 fn deref(&self) -> &Self::Target {
240 self.connections.get_or(|| {
241 Self::create_connection(self.persistent, &self.uri, self.connection_initialize_query)
242 })
243 }
244}
245
246pub fn background_thread_queue() -> WriteQueueConstructor {
247 use std::sync::mpsc::channel;
248
249 Box::new(|| {
250 let (sender, receiver) = channel::<QueuedWrite>();
251
252 thread::Builder::new()
253 .name("sqlezWorker".to_string())
254 .spawn(move || {
255 while let Ok(write) = receiver.recv() {
256 write()
257 }
258 })
259 .unwrap();
260
261 let sender = UnboundedSyncSender::new(sender);
262 Box::new(move |queued_write| {
263 sender
264 .send(queued_write)
265 .expect("Could not send write action to background thread");
266 })
267 })
268}
269
270pub fn locking_queue() -> WriteQueueConstructor {
271 Box::new(|| {
272 let write_mutex = Mutex::new(());
273 Box::new(move |queued_write| {
274 let _lock = write_mutex.lock();
275 queued_write();
276 })
277 })
278}
279
#[cfg(test)]
mod test {
    use indoc::indoc;
    use std::{ops::Deref, thread};

    use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};

    /// Builds 100 connections to the same shared in-memory database concurrently. Each build
    /// queues the db initialization query and the migrations for `TestDomain`.
    #[test]
    fn many_initialize_and_migrate_queries_at_once() {
        enum TestDomain {}
        impl Domain for TestDomain {
            const NAME: &str = "test";
            const MIGRATIONS: &[&str] = &["CREATE TABLE test(col1 TEXT, col2 TEXT) STRICT;"];
        }

        let handles: Vec<_> = (0..100)
            .map(|_| {
                thread::spawn(|| {
                    let builder =
                        ThreadSafeConnection::builder::<TestDomain>("annoying-test.db", false)
                            .with_db_initialization_query("PRAGMA journal_mode=WAL")
                            .with_connection_initialize_query(indoc! {"
                                PRAGMA synchronous=NORMAL;
                                PRAGMA busy_timeout=1;
                                PRAGMA foreign_keys=TRUE;
                                PRAGMA case_sensitive_like=TRUE;
                            "});

                    let _ = smol::block_on(builder.build()).unwrap().deref();
                })
            })
            .collect();

        // NOTE(review): join results are discarded, so a panic inside a spawned thread
        // (including the `unwrap` on `build`) does not fail this test. Confirm this is intended.
        for handle in handles {
            let _ = handle.join();
        }
    }

    /// `build` must surface a failing migration. The `workspaces` table declares a foreign key
    /// on `active_pane`, which is not one of its columns, and the test expects the `unwrap` on
    /// the build result to panic.
    #[test]
    #[should_panic]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}
        impl Domain for TestWorkspace {
            const NAME: &str = "workspace";

            const MIGRATIONS: &[&str] = &["
                CREATE TABLE workspaces(
                    workspace_id INTEGER PRIMARY KEY,
                    dock_visible INTEGER, -- Boolean
                    dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                    dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                    timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                    FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                    FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                ) STRICT;

                CREATE TABLE panes(
                    pane_id INTEGER PRIMARY KEY,
                    workspace_id INTEGER NOT NULL,
                    active INTEGER NOT NULL, -- Boolean
                    FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
                        ON DELETE CASCADE
                        ON UPDATE CASCADE
                ) STRICT;
            "];
        }

        let build = ThreadSafeConnection::builder::<TestWorkspace>("wild_zed_lost_failure", false)
            .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true")
            .build();

        smol::block_on(build).unwrap();
    }
}