fix test failures

Kay Simmons created

Change summary

crates/db/src/db.rs                        |  9 ++
crates/sqlez/src/thread_safe_connection.rs | 89 ++++++++++++++---------
2 files changed, 64 insertions(+), 34 deletions(-)

Detailed changes

crates/db/src/db.rs 🔗

@@ -4,6 +4,7 @@ pub mod kvp;
 pub use anyhow;
 pub use indoc::indoc;
 pub use lazy_static;
+use parking_lot::Mutex;
 pub use smol;
 pub use sqlez;
 pub use sqlez_macros;
@@ -59,6 +60,14 @@ pub async fn open_memory_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<
     ThreadSafeConnection::<M>::builder(db_name, false)
         .with_db_initialization_query(DB_INITIALIZE_QUERY)
         .with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
+        // Serialize queued writes via a mutex and run them synchronously
+        .with_write_queue_constructor(Box::new(|connection| {
+            let connection = Mutex::new(connection);
+            Box::new(move |queued_write| {
+                let connection = connection.lock();
+                queued_write(&connection)
+            })
+        }))
         .build()
         .await
 }

crates/sqlez/src/thread_safe_connection.rs 🔗

@@ -13,12 +13,14 @@ use crate::{
 const MIGRATION_RETRIES: usize = 10;
 
 type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
+type WriteQueueConstructor =
+    Box<dyn 'static + Send + FnMut(Connection) -> Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>;
 lazy_static! {
     /// List of queues of tasks by database uri. This lets us serialize writes to the database
     /// and have a single worker thread per db file. This means many thread safe connections
     /// (possibly with different migrations) could all be communicating with the same background
     /// thread.
-    static ref QUEUES: RwLock<HashMap<Arc<str>, UnboundedSyncSender<QueuedWrite>>> =
+    static ref QUEUES: RwLock<HashMap<Arc<str>, Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>> =
         Default::default();
 }
 
@@ -38,6 +40,7 @@ unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
 
 pub struct ThreadSafeConnectionBuilder<M: Migrator = ()> {
     db_initialize_query: Option<&'static str>,
+    write_queue_constructor: Option<WriteQueueConstructor>,
     connection: ThreadSafeConnection<M>,
 }
 
@@ -50,6 +53,18 @@ impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
         self
     }
 
+    /// Specifies how the thread safe connection should serialize writes. If provided
+    /// the connection will call the write_queue_constructor for each database file in
+    /// this process. The constructor is responsible for setting up a background thread or
+    /// async task which handles queued writes with the provided connection.
+    pub fn with_write_queue_constructor(
+        mut self,
+        write_queue_constructor: WriteQueueConstructor,
+    ) -> Self {
+        self.write_queue_constructor = Some(write_queue_constructor);
+        self
+    }
+
     /// Queues an initialization query for the database file. This must be infallible
     /// but may cause changes to the database file such as with `PRAGMA journal_mode`
     pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
@@ -58,6 +73,38 @@ impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
     }
 
     pub async fn build(self) -> ThreadSafeConnection<M> {
+        if !QUEUES.read().contains_key(&self.connection.uri) {
+            let mut queues = QUEUES.write();
+            if !queues.contains_key(&self.connection.uri) {
+                let mut write_connection = self.connection.create_connection();
+                // Enable writes for this connection
+                write_connection.write = true;
+                if let Some(mut write_queue_constructor) = self.write_queue_constructor {
+                    let write_channel = write_queue_constructor(write_connection);
+                    queues.insert(self.connection.uri.clone(), write_channel);
+                } else {
+                    use std::sync::mpsc::channel;
+
+                    let (sender, reciever) = channel::<QueuedWrite>();
+                    thread::spawn(move || {
+                        while let Ok(write) = reciever.recv() {
+                            write(&write_connection)
+                        }
+                    });
+
+                    let sender = UnboundedSyncSender::new(sender);
+                    queues.insert(
+                        self.connection.uri.clone(),
+                        Box::new(move |queued_write| {
+                            sender
+                                .send(queued_write)
+                                .expect("Could not send write action to backgorund thread");
+                        }),
+                    );
+                }
+            }
+        }
+
         let db_initialize_query = self.db_initialize_query;
 
         self.connection
@@ -90,6 +137,7 @@ impl<M: Migrator> ThreadSafeConnection<M> {
     pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
         ThreadSafeConnectionBuilder::<M> {
             db_initialize_query: None,
+            write_queue_constructor: None,
             connection: Self {
                 uri: Arc::from(uri),
                 persistent,
@@ -112,48 +160,21 @@ impl<M: Migrator> ThreadSafeConnection<M> {
         Connection::open_memory(Some(self.uri.as_ref()))
     }
 
-    fn queue_write_task(&self, callback: QueuedWrite) {
-        // Startup write thread for this database if one hasn't already
-        // been started and insert a channel to queue work for it
-        if !QUEUES.read().contains_key(&self.uri) {
-            let mut queues = QUEUES.write();
-            if !queues.contains_key(&self.uri) {
-                use std::sync::mpsc::channel;
-
-                let (sender, reciever) = channel::<QueuedWrite>();
-                let mut write_connection = self.create_connection();
-                // Enable writes for this connection
-                write_connection.write = true;
-                thread::spawn(move || {
-                    while let Ok(write) = reciever.recv() {
-                        write(&write_connection)
-                    }
-                });
-
-                queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender));
-            }
-        }
-
-        // Grab the queue for this database
-        let queues = QUEUES.read();
-        let write_channel = queues.get(&self.uri).unwrap();
-
-        write_channel
-            .send(callback)
-            .expect("Could not send write action to backgorund thread");
-    }
-
     pub fn write<T: 'static + Send + Sync>(
         &self,
         callback: impl 'static + Send + FnOnce(&Connection) -> T,
     ) -> impl Future<Output = T> {
+        let queues = QUEUES.read();
+        let write_channel = queues
+            .get(&self.uri)
+            .expect("Queues are inserted when build is called. This should always succeed");
+
         // Create a one shot channel for the result of the queued write
         // so we can await on the result
         let (sender, reciever) = oneshot::channel();
-        self.queue_write_task(Box::new(move |connection| {
+        write_channel(Box::new(move |connection| {
             sender.send(callback(connection)).ok();
         }));
-
         reciever.map(|response| response.expect("Background writer thread unexpectedly closed"))
     }