thread_safe_connection.rs

  1use anyhow::Context as _;
  2use collections::HashMap;
  3use futures::{Future, FutureExt, channel::oneshot};
  4use parking_lot::{Mutex, RwLock};
  5use std::{
  6    marker::PhantomData,
  7    ops::Deref,
  8    sync::{Arc, LazyLock},
  9    thread,
 10    time::Duration,
 11};
 12use thread_local::ThreadLocal;
 13
 14use crate::{connection::Connection, domain::Migrator, util::UnboundedSyncSender};
 15
 16const MIGRATION_RETRIES: usize = 10;
 17const CONNECTION_INITIALIZE_RETRIES: usize = 50;
 18const CONNECTION_INITIALIZE_RETRY_DELAY: Duration = Duration::from_millis(1);
 19
 20type QueuedWrite = Box<dyn 'static + Send + FnOnce()>;
 21type WriteQueue = Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>;
 22type WriteQueueConstructor = Box<dyn 'static + Send + FnMut() -> WriteQueue>;
 23
 24/// List of queues of tasks by database uri. This lets us serialize writes to the database
 25/// and have a single worker thread per db file. This means many thread safe connections
 26/// (possibly with different migrations) could all be communicating with the same background
 27/// thread.
 28static QUEUES: LazyLock<RwLock<HashMap<Arc<str>, WriteQueue>>> = LazyLock::new(Default::default);
 29
/// Thread safe connection to a given database file or shared in-memory db. It can be cloned,
/// shared between threads, and stored in statics.
///
/// It derefs to a synchronous connection that is opened lazily for, and used only by, the
/// thread that dereferences it; those per-thread connections have writes disabled. Write
/// access is obtained by passing a callback to [`ThreadSafeConnection::write`], which queues
/// the callback on the database's shared write queue.
#[derive(Clone)]
pub struct ThreadSafeConnection {
    /// File path when `persistent`, otherwise the name of a shared in-memory database. Also
    /// the key of this database's write queue in the process-wide queue map.
    uri: Arc<str>,
    /// Open `uri` as a file (`true`) or as shared memory (`false`).
    persistent: bool,
    /// Query run on every newly opened per-thread connection.
    connection_initialize_query: Option<&'static str>,
    /// One connection per thread, shared by all clones of this value.
    connections: Arc<ThreadLocal<Connection>>,
}

// SAFETY: NOTE(review): these impls are asserted by hand, presumably because `Connection` is
// not `Send`/`Sync`. Soundness relies on each thread only using the `Connection` it created
// through `ThreadLocal::get_or` in `deref`; `write` does move a clone to the write-queue
// thread, but that thread then uses its own thread-local connection. However, the last clone
// to drop frees every thread's connection on whichever thread it is dropped, so confirm that
// `Connection` may be dropped off its creating thread.
unsafe impl Send for ThreadSafeConnection {}
unsafe impl Sync for ThreadSafeConnection {}
 43
/// Configures a [`ThreadSafeConnection`]. Awaiting `build` registers the write queue, runs the
/// optional db initialization query, and applies the migrations of `M`.
pub struct ThreadSafeConnectionBuilder<M: Migrator + 'static = ()> {
    /// Query run once through the write queue before the migrations.
    db_initialize_query: Option<&'static str>,
    /// Custom write-queue constructor; used only if no queue exists yet for the uri. When
    /// `None`, `background_thread_queue` is used in that case.
    write_queue_constructor: Option<WriteQueueConstructor>,
    /// The connection being configured; returned by `build`.
    connection: ThreadSafeConnection,
    // Ties the builder to `M`. NOTE(review): the raw pointer also makes the builder
    // `!Send`/`!Sync` — confirm that is intended rather than `PhantomData<fn() -> M>`.
    _migrator: PhantomData<*mut M>,
}
 50
 51impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
 52    /// Sets the query to run every time a connection is opened. This must
 53    /// be infallible (EG only use pragma statements) and not cause writes.
 54    /// to the db or it will panic.
 55    pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
 56        self.connection.connection_initialize_query = Some(initialize_query);
 57        self
 58    }
 59
 60    /// Queues an initialization query for the database file. This must be infallible
 61    /// but may cause changes to the database file such as with `PRAGMA journal_mode`
 62    pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
 63        self.db_initialize_query = Some(initialize_query);
 64        self
 65    }
 66
 67    /// Specifies how the thread safe connection should serialize writes. If provided
 68    /// the connection will call the write_queue_constructor for each database file in
 69    /// this process. The constructor is responsible for setting up a background thread or
 70    /// async task which handles queued writes with the provided connection.
 71    pub fn with_write_queue_constructor(
 72        mut self,
 73        write_queue_constructor: WriteQueueConstructor,
 74    ) -> Self {
 75        self.write_queue_constructor = Some(write_queue_constructor);
 76        self
 77    }
 78
 79    pub async fn build(self) -> anyhow::Result<ThreadSafeConnection> {
 80        self.connection
 81            .initialize_queues(self.write_queue_constructor);
 82
 83        let db_initialize_query = self.db_initialize_query;
 84
 85        self.connection
 86            .write(move |connection| {
 87                if let Some(db_initialize_query) = db_initialize_query {
 88                    connection.exec(db_initialize_query).with_context(|| {
 89                        format!(
 90                            "Db initialize query failed to execute: {}",
 91                            db_initialize_query
 92                        )
 93                    })?()?;
 94                }
 95
 96                // Retry failed migrations in case they were run in parallel from different
 97                // processes. This gives a best attempt at migrating before bailing
 98                let mut migration_result =
 99                    anyhow::Result::<()>::Err(anyhow::anyhow!("Migration never run"));
100
101                let foreign_keys_enabled: bool =
102                    connection.select_row::<i32>("PRAGMA foreign_keys")?()
103                        .unwrap_or(None)
104                        .map(|enabled| enabled != 0)
105                        .unwrap_or(false);
106
107                connection.exec("PRAGMA foreign_keys = OFF;")?()?;
108
109                for _ in 0..MIGRATION_RETRIES {
110                    migration_result = connection
111                        .with_savepoint("thread_safe_multi_migration", || M::migrate(connection));
112
113                    if migration_result.is_ok() {
114                        break;
115                    }
116                }
117
118                if foreign_keys_enabled {
119                    connection.exec("PRAGMA foreign_keys = ON;")?()?;
120                }
121                migration_result
122            })
123            .await?;
124
125        Ok(self.connection)
126    }
127}
128
129impl ThreadSafeConnection {
130    fn initialize_queues(&self, write_queue_constructor: Option<WriteQueueConstructor>) -> bool {
131        if !QUEUES.read().contains_key(&self.uri) {
132            let mut queues = QUEUES.write();
133            if !queues.contains_key(&self.uri) {
134                let mut write_queue_constructor =
135                    write_queue_constructor.unwrap_or_else(background_thread_queue);
136                queues.insert(self.uri.clone(), write_queue_constructor());
137                return true;
138            }
139        }
140        false
141    }
142
143    pub fn builder<M: Migrator>(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
144        ThreadSafeConnectionBuilder::<M> {
145            db_initialize_query: None,
146            write_queue_constructor: None,
147            connection: Self {
148                uri: Arc::from(uri),
149                persistent,
150                connection_initialize_query: None,
151                connections: Default::default(),
152            },
153            _migrator: PhantomData,
154        }
155    }
156
157    /// Opens a new db connection with the initialized file path. This is internal and only
158    /// called from the deref function.
159    fn open_file(uri: &str) -> Connection {
160        Connection::open_file(uri)
161    }
162
163    /// Opens a shared memory connection using the file path as the identifier. This is internal
164    /// and only called from the deref function.
165    fn open_shared_memory(uri: &str) -> Connection {
166        Connection::open_memory(Some(uri))
167    }
168
169    pub fn write<T: 'static + Send + Sync>(
170        &self,
171        callback: impl 'static + Send + FnOnce(&Connection) -> T,
172    ) -> impl Future<Output = T> {
173        // Check and invalidate queue and maybe recreate queue
174        let queues = QUEUES.read();
175        let write_channel = queues
176            .get(&self.uri)
177            .expect("Queues are inserted when build is called. This should always succeed");
178
179        // Create a one shot channel for the result of the queued write
180        // so we can await on the result
181        let (sender, receiver) = oneshot::channel();
182
183        let thread_safe_connection = (*self).clone();
184        write_channel(Box::new(move || {
185            let connection = thread_safe_connection.deref();
186            let result = connection.with_write(|connection| callback(connection));
187            sender.send(result).ok();
188        }));
189        receiver.map(|response| response.expect("Write queue unexpectedly closed"))
190    }
191
192    pub(crate) fn create_connection(
193        persistent: bool,
194        uri: &str,
195        connection_initialize_query: Option<&'static str>,
196    ) -> Connection {
197        let mut connection = if persistent {
198            Self::open_file(uri)
199        } else {
200            Self::open_shared_memory(uri)
201        };
202
203        if let Some(initialize_query) = connection_initialize_query {
204            let mut last_error = None;
205            let initialized = (0..CONNECTION_INITIALIZE_RETRIES).any(|attempt| {
206                match connection
207                    .exec(initialize_query)
208                    .and_then(|mut statement| statement())
209                {
210                    Ok(()) => true,
211                    Err(err)
212                        if is_schema_lock_error(&err)
213                            && attempt + 1 < CONNECTION_INITIALIZE_RETRIES =>
214                    {
215                        last_error = Some(err);
216                        thread::sleep(CONNECTION_INITIALIZE_RETRY_DELAY);
217                        false
218                    }
219                    Err(err) => {
220                        panic!(
221                            "Initialize query failed to execute: {}\n\nCaused by:\n{err:#}",
222                            initialize_query
223                        )
224                    }
225                }
226            });
227
228            if !initialized {
229                let err = last_error
230                    .expect("connection initialization retries should record the last error");
231                panic!(
232                    "Initialize query failed to execute after retries: {}\n\nCaused by:\n{err:#}",
233                    initialize_query
234                );
235            }
236        }
237
238        // Disallow writes on the connection. The only writes allowed for thread safe connections
239        // are from the background thread that can serialize them.
240        *connection.write.get_mut() = false;
241
242        connection
243    }
244}
245
246fn is_schema_lock_error(err: &anyhow::Error) -> bool {
247    let message = format!("{err:#}");
248    message.contains("database schema is locked") || message.contains("database is locked")
249}
250
251impl ThreadSafeConnection {
252    /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
253    /// This allows construction to be infallible and not write to the db.
254    pub fn new(
255        uri: &str,
256        persistent: bool,
257        connection_initialize_query: Option<&'static str>,
258        write_queue_constructor: Option<WriteQueueConstructor>,
259    ) -> Self {
260        let connection = Self {
261            uri: Arc::from(uri),
262            persistent,
263            connection_initialize_query,
264            connections: Default::default(),
265        };
266
267        connection.initialize_queues(write_queue_constructor);
268        connection
269    }
270}
271
272impl Deref for ThreadSafeConnection {
273    type Target = Connection;
274
275    fn deref(&self) -> &Self::Target {
276        self.connections.get_or(|| {
277            Self::create_connection(self.persistent, &self.uri, self.connection_initialize_query)
278        })
279    }
280}
281
282pub fn background_thread_queue() -> WriteQueueConstructor {
283    use std::sync::mpsc::channel;
284
285    Box::new(|| {
286        let (sender, receiver) = channel::<QueuedWrite>();
287
288        thread::Builder::new()
289            .name("sqlezWorker".to_string())
290            .spawn(move || {
291                while let Ok(write) = receiver.recv() {
292                    write()
293                }
294            })
295            .unwrap();
296
297        let sender = UnboundedSyncSender::new(sender);
298        Box::new(move |queued_write| {
299            sender
300                .send(queued_write)
301                .expect("Could not send write action to background thread");
302        })
303    })
304}
305
306pub fn locking_queue() -> WriteQueueConstructor {
307    Box::new(|| {
308        let write_mutex = Mutex::new(());
309        Box::new(move |queued_write| {
310            let _lock = write_mutex.lock();
311            queued_write();
312        })
313    })
314}
315
#[cfg(test)]
mod test {
    use indoc::indoc;
    use std::ops::Deref;

    use std::{thread, time::Duration};

    use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};

    /// Builds the same shared in-memory database from 100 threads at once. Every build (db
    /// initialization plus migration) and the following per-thread connection open must
    /// succeed; a panic in any thread fails the test.
    #[test]
    fn many_initialize_and_migrate_queries_at_once() {
        let mut handles = vec![];

        enum TestDomain {}
        impl Domain for TestDomain {
            const NAME: &str = "test";
            const MIGRATIONS: &[&str] = &["CREATE TABLE test(col1 TEXT, col2 TEXT) STRICT;"];
        }

        for _ in 0..100 {
            handles.push(thread::spawn(|| {
                let builder =
                    ThreadSafeConnection::builder::<TestDomain>("annoying-test.db", false)
                        .with_db_initialization_query("PRAGMA journal_mode=WAL")
                        .with_connection_initialize_query(indoc! {"
                                PRAGMA synchronous=NORMAL;
                                PRAGMA busy_timeout=1;
                                PRAGMA foreign_keys=TRUE;
                                PRAGMA case_sensitive_like=TRUE;
                            "});

                let _ = smol::block_on(builder.build()).unwrap().deref();
            }));
        }

        for handle in handles {
            // Propagate panics from the spawned threads. Discarding the join result would let
            // a failed build or connection open pass silently.
            handle
                .join()
                .expect("a thread failed while building or opening the connection");
        }
    }

    /// Holds an uncommitted schema change on a shared in-memory database for about 10ms. While
    /// it is held, opening another connection with an initialize query must retry through the
    /// lock error this presumably causes, rather than panicking.
    #[test]
    fn connection_initialize_query_retries_transient_schema_lock() {
        let name = "connection_initialize_query_retries_transient_schema_lock";
        let locking_connection = crate::connection::Connection::open_memory(Some(name));
        locking_connection.exec("BEGIN IMMEDIATE").unwrap()().unwrap();
        locking_connection
            .exec("CREATE TABLE test(col TEXT)")
            .unwrap()()
        .unwrap();

        let releaser = thread::spawn(move || {
            thread::sleep(Duration::from_millis(10));
            locking_connection.exec("ROLLBACK").unwrap()().unwrap();
        });

        ThreadSafeConnection::create_connection(false, name, Some("PRAGMA FOREIGN_KEYS=true"));
        releaser.join().unwrap();
    }
}