1use anyhow::Context as _;
2use collections::HashMap;
3use futures::{Future, FutureExt, channel::oneshot};
4use parking_lot::{Mutex, RwLock};
5use std::{
6 marker::PhantomData,
7 ops::Deref,
8 sync::{Arc, LazyLock},
9 thread,
10 time::Duration,
11};
12use thread_local::ThreadLocal;
13
14use crate::{connection::Connection, domain::Migrator, util::UnboundedSyncSender};
15
/// Upper bound on how many times the migrations are attempted while building a connection.
const MIGRATION_RETRIES: usize = 10;
/// Upper bound on how many times the per-connection initialize query is attempted.
const CONNECTION_INITIALIZE_RETRIES: usize = 50;
/// Pause between two consecutive attempts of the per-connection initialize query.
const CONNECTION_INITIALIZE_RETRY_DELAY: Duration = Duration::from_millis(1);
19
/// A single unit of write work, ready to be executed exactly once.
type QueuedWrite = Box<dyn FnOnce() + Send + 'static>;
/// Accepts queued writes; how and where they are executed is up to the queue.
type WriteQueue = Box<dyn Fn(QueuedWrite) + Send + Sync + 'static>;
/// Factory producing a `WriteQueue`; it is called when a queue has to be created.
type WriteQueueConstructor = Box<dyn FnMut() -> WriteQueue + Send + 'static>;
23
24/// List of queues of tasks by database uri. This lets us serialize writes to the database
25/// and have a single worker thread per db file. This means many thread safe connections
26/// (possibly with different migrations) could all be communicating with the same background
27/// thread.
28static QUEUES: LazyLock<RwLock<HashMap<Arc<str>, WriteQueue>>> = LazyLock::new(Default::default);
29
30/// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static,
31/// whatever. It derefs to a synchronous connection by thread that is read only. A write capable connection
32/// may be accessed by passing a callback to the `write` function which will queue the callback
33#[derive(Clone)]
34pub struct ThreadSafeConnection {
35 uri: Arc<str>,
36 persistent: bool,
37 connection_initialize_query: Option<&'static str>,
38 connections: Arc<ThreadLocal<Connection>>,
39}
40
41unsafe impl Send for ThreadSafeConnection {}
42unsafe impl Sync for ThreadSafeConnection {}
43
44pub struct ThreadSafeConnectionBuilder<M: Migrator + 'static = ()> {
45 db_initialize_query: Option<&'static str>,
46 write_queue_constructor: Option<WriteQueueConstructor>,
47 connection: ThreadSafeConnection,
48 _migrator: PhantomData<*mut M>,
49}
50
51impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
52 /// Sets the query to run every time a connection is opened. This must
53 /// be infallible (EG only use pragma statements) and not cause writes.
54 /// to the db or it will panic.
55 pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
56 self.connection.connection_initialize_query = Some(initialize_query);
57 self
58 }
59
60 /// Queues an initialization query for the database file. This must be infallible
61 /// but may cause changes to the database file such as with `PRAGMA journal_mode`
62 pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
63 self.db_initialize_query = Some(initialize_query);
64 self
65 }
66
67 /// Specifies how the thread safe connection should serialize writes. If provided
68 /// the connection will call the write_queue_constructor for each database file in
69 /// this process. The constructor is responsible for setting up a background thread or
70 /// async task which handles queued writes with the provided connection.
71 pub fn with_write_queue_constructor(
72 mut self,
73 write_queue_constructor: WriteQueueConstructor,
74 ) -> Self {
75 self.write_queue_constructor = Some(write_queue_constructor);
76 self
77 }
78
79 pub async fn build(self) -> anyhow::Result<ThreadSafeConnection> {
80 self.connection
81 .initialize_queues(self.write_queue_constructor);
82
83 let db_initialize_query = self.db_initialize_query;
84
85 self.connection
86 .write(move |connection| {
87 if let Some(db_initialize_query) = db_initialize_query {
88 connection.exec(db_initialize_query).with_context(|| {
89 format!(
90 "Db initialize query failed to execute: {}",
91 db_initialize_query
92 )
93 })?()?;
94 }
95
96 // Retry failed migrations in case they were run in parallel from different
97 // processes. This gives a best attempt at migrating before bailing
98 let mut migration_result =
99 anyhow::Result::<()>::Err(anyhow::anyhow!("Migration never run"));
100
101 let foreign_keys_enabled: bool =
102 connection.select_row::<i32>("PRAGMA foreign_keys")?()
103 .unwrap_or(None)
104 .map(|enabled| enabled != 0)
105 .unwrap_or(false);
106
107 connection.exec("PRAGMA foreign_keys = OFF;")?()?;
108
109 for _ in 0..MIGRATION_RETRIES {
110 migration_result = connection
111 .with_savepoint("thread_safe_multi_migration", || M::migrate(connection));
112
113 if migration_result.is_ok() {
114 break;
115 }
116 }
117
118 if foreign_keys_enabled {
119 connection.exec("PRAGMA foreign_keys = ON;")?()?;
120 }
121 migration_result
122 })
123 .await?;
124
125 Ok(self.connection)
126 }
127}
128
129impl ThreadSafeConnection {
130 fn initialize_queues(&self, write_queue_constructor: Option<WriteQueueConstructor>) -> bool {
131 if !QUEUES.read().contains_key(&self.uri) {
132 let mut queues = QUEUES.write();
133 if !queues.contains_key(&self.uri) {
134 let mut write_queue_constructor =
135 write_queue_constructor.unwrap_or_else(background_thread_queue);
136 queues.insert(self.uri.clone(), write_queue_constructor());
137 return true;
138 }
139 }
140 false
141 }
142
143 pub fn builder<M: Migrator>(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
144 ThreadSafeConnectionBuilder::<M> {
145 db_initialize_query: None,
146 write_queue_constructor: None,
147 connection: Self {
148 uri: Arc::from(uri),
149 persistent,
150 connection_initialize_query: None,
151 connections: Default::default(),
152 },
153 _migrator: PhantomData,
154 }
155 }
156
157 /// Opens a new db connection with the initialized file path. This is internal and only
158 /// called from the deref function.
159 fn open_file(uri: &str) -> Connection {
160 Connection::open_file(uri)
161 }
162
163 /// Opens a shared memory connection using the file path as the identifier. This is internal
164 /// and only called from the deref function.
165 fn open_shared_memory(uri: &str) -> Connection {
166 Connection::open_memory(Some(uri))
167 }
168
169 pub fn write<T: 'static + Send + Sync>(
170 &self,
171 callback: impl 'static + Send + FnOnce(&Connection) -> T,
172 ) -> impl Future<Output = T> {
173 // Check and invalidate queue and maybe recreate queue
174 let queues = QUEUES.read();
175 let write_channel = queues
176 .get(&self.uri)
177 .expect("Queues are inserted when build is called. This should always succeed");
178
179 // Create a one shot channel for the result of the queued write
180 // so we can await on the result
181 let (sender, receiver) = oneshot::channel();
182
183 let thread_safe_connection = (*self).clone();
184 write_channel(Box::new(move || {
185 let connection = thread_safe_connection.deref();
186 let result = connection.with_write(|connection| callback(connection));
187 sender.send(result).ok();
188 }));
189 receiver.map(|response| response.expect("Write queue unexpectedly closed"))
190 }
191
192 pub(crate) fn create_connection(
193 persistent: bool,
194 uri: &str,
195 connection_initialize_query: Option<&'static str>,
196 ) -> Connection {
197 let mut connection = if persistent {
198 Self::open_file(uri)
199 } else {
200 Self::open_shared_memory(uri)
201 };
202
203 if let Some(initialize_query) = connection_initialize_query {
204 let mut last_error = None;
205 let initialized = (0..CONNECTION_INITIALIZE_RETRIES).any(|attempt| {
206 match connection
207 .exec(initialize_query)
208 .and_then(|mut statement| statement())
209 {
210 Ok(()) => true,
211 Err(err)
212 if is_schema_lock_error(&err)
213 && attempt + 1 < CONNECTION_INITIALIZE_RETRIES =>
214 {
215 last_error = Some(err);
216 thread::sleep(CONNECTION_INITIALIZE_RETRY_DELAY);
217 false
218 }
219 Err(err) => {
220 panic!(
221 "Initialize query failed to execute: {}\n\nCaused by:\n{err:#}",
222 initialize_query
223 )
224 }
225 }
226 });
227
228 if !initialized {
229 let err = last_error
230 .expect("connection initialization retries should record the last error");
231 panic!(
232 "Initialize query failed to execute after retries: {}\n\nCaused by:\n{err:#}",
233 initialize_query
234 );
235 }
236 }
237
238 // Disallow writes on the connection. The only writes allowed for thread safe connections
239 // are from the background thread that can serialize them.
240 *connection.write.get_mut() = false;
241
242 connection
243 }
244}
245
246fn is_schema_lock_error(err: &anyhow::Error) -> bool {
247 let message = format!("{err:#}");
248 message.contains("database schema is locked") || message.contains("database is locked")
249}
250
251impl ThreadSafeConnection {
252 /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
253 /// This allows construction to be infallible and not write to the db.
254 pub fn new(
255 uri: &str,
256 persistent: bool,
257 connection_initialize_query: Option<&'static str>,
258 write_queue_constructor: Option<WriteQueueConstructor>,
259 ) -> Self {
260 let connection = Self {
261 uri: Arc::from(uri),
262 persistent,
263 connection_initialize_query,
264 connections: Default::default(),
265 };
266
267 connection.initialize_queues(write_queue_constructor);
268 connection
269 }
270}
271
272impl Deref for ThreadSafeConnection {
273 type Target = Connection;
274
275 fn deref(&self) -> &Self::Target {
276 self.connections.get_or(|| {
277 Self::create_connection(self.persistent, &self.uri, self.connection_initialize_query)
278 })
279 }
280}
281
282pub fn background_thread_queue() -> WriteQueueConstructor {
283 use std::sync::mpsc::channel;
284
285 Box::new(|| {
286 let (sender, receiver) = channel::<QueuedWrite>();
287
288 thread::Builder::new()
289 .name("sqlezWorker".to_string())
290 .spawn(move || {
291 while let Ok(write) = receiver.recv() {
292 write()
293 }
294 })
295 .unwrap();
296
297 let sender = UnboundedSyncSender::new(sender);
298 Box::new(move |queued_write| {
299 sender
300 .send(queued_write)
301 .expect("Could not send write action to background thread");
302 })
303 })
304}
305
306pub fn locking_queue() -> WriteQueueConstructor {
307 Box::new(|| {
308 let write_mutex = Mutex::new(());
309 Box::new(move |queued_write| {
310 let _lock = write_mutex.lock();
311 queued_write();
312 })
313 })
314}
315
#[cfg(test)]
mod test {
    use indoc::indoc;
    use std::ops::Deref;

    use std::{thread, time::Duration};

    use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};

    /// Builds (initializes and migrates) the same shared in-memory database from many threads
    /// at once and opens a connection on each of them.
    #[test]
    fn many_initialize_and_migrate_queries_at_once() {
        enum TestDomain {}
        impl Domain for TestDomain {
            const NAME: &str = "test";
            const MIGRATIONS: &[&str] = &["CREATE TABLE test(col1 TEXT, col2 TEXT) STRICT;"];
        }

        let handles: Vec<_> = (0..100)
            .map(|_| {
                thread::spawn(|| {
                    let builder =
                        ThreadSafeConnection::builder::<TestDomain>("annoying-test.db", false)
                            .with_db_initialization_query("PRAGMA journal_mode=WAL")
                            .with_connection_initialize_query(indoc! {"
                                PRAGMA synchronous=NORMAL;
                                PRAGMA busy_timeout=1;
                                PRAGMA foreign_keys=TRUE;
                                PRAGMA case_sensitive_like=TRUE;
                            "});

                    // Dereferencing forces a connection to be opened on this thread.
                    let _ = smol::block_on(builder.build()).unwrap().deref();
                })
            })
            .collect();

        // A panic inside a spawned thread only surfaces through `join`; ignoring its result
        // would let this test pass no matter what happened in the threads.
        for handle in handles {
            handle
                .join()
                .expect("a thread panicked while building or opening the connection");
        }
    }

    /// Holds a write transaction with a schema change open on a shared in-memory database
    /// for a short while and checks that opening another connection with an initialize
    /// query does not panic in the meantime.
    #[test]
    fn connection_initialize_query_retries_transient_schema_lock() {
        let name = "connection_initialize_query_retries_transient_schema_lock";
        let locking_connection = crate::connection::Connection::open_memory(Some(name));
        locking_connection.exec("BEGIN IMMEDIATE").unwrap()().unwrap();
        locking_connection
            .exec("CREATE TABLE test(col TEXT)")
            .unwrap()()
            .unwrap();

        // Release the transaction from another thread after a short delay.
        let releaser = thread::spawn(move || {
            thread::sleep(Duration::from_millis(10));
            locking_connection.exec("ROLLBACK").unwrap()().unwrap();
        });

        ThreadSafeConnection::create_connection(false, name, Some("PRAGMA FOREIGN_KEYS=true"));
        releaser.join().unwrap();
    }
}