@@ -4,6 +4,7 @@ pub mod kvp;
pub use anyhow;
pub use indoc::indoc;
pub use lazy_static;
+use parking_lot::Mutex;
pub use smol;
pub use sqlez;
pub use sqlez_macros;
@@ -59,6 +60,14 @@ pub async fn open_memory_db<M: Migrator>(db_name: &str) -> ThreadSafeConnection<
ThreadSafeConnection::<M>::builder(db_name, false)
.with_db_initialization_query(DB_INITIALIZE_QUERY)
.with_connection_initialize_query(CONNECTION_INITIALIZE_QUERY)
+ // Serialize queued writes via a mutex and run them synchronously
+ .with_write_queue_constructor(Box::new(|connection| {
+ let connection = Mutex::new(connection);
+ Box::new(move |queued_write| {
+ let connection = connection.lock();
+ queued_write(&connection)
+ })
+ }))
.build()
.await
}
@@ -13,12 +13,14 @@ use crate::{
const MIGRATION_RETRIES: usize = 10;
type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
+type WriteQueueConstructor =
+ Box<dyn 'static + Send + FnMut(Connection) -> Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>;
lazy_static! {
/// List of queues of tasks by database uri. This lets us serialize writes to the database
/// and have a single worker thread per db file. This means many thread safe connections
/// (possibly with different migrations) could all be communicating with the same background
/// thread.
- static ref QUEUES: RwLock<HashMap<Arc<str>, UnboundedSyncSender<QueuedWrite>>> =
+ static ref QUEUES: RwLock<HashMap<Arc<str>, Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>> =
Default::default();
}
@@ -38,6 +40,7 @@ unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
pub struct ThreadSafeConnectionBuilder<M: Migrator = ()> {
db_initialize_query: Option<&'static str>,
+ write_queue_constructor: Option<WriteQueueConstructor>,
connection: ThreadSafeConnection<M>,
}
@@ -50,6 +53,18 @@ impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
self
}
+ /// Specifies how the thread safe connection should serialize writes. If provided
+ /// the connection will call the write_queue_constructor for each database file in
+ /// this process. The constructor is responsible for setting up a background thread or
+ /// async task which handles queued writes with the provided connection.
+ pub fn with_write_queue_constructor(
+ mut self,
+ write_queue_constructor: WriteQueueConstructor,
+ ) -> Self {
+ self.write_queue_constructor = Some(write_queue_constructor);
+ self
+ }
+
/// Queues an initialization query for the database file. This must be infallible
/// but may cause changes to the database file such as with `PRAGMA journal_mode`
pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
@@ -58,6 +73,38 @@ impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
}
pub async fn build(self) -> ThreadSafeConnection<M> {
+ if !QUEUES.read().contains_key(&self.connection.uri) {
+ let mut queues = QUEUES.write();
+ if !queues.contains_key(&self.connection.uri) {
+ let mut write_connection = self.connection.create_connection();
+ // Enable writes for this connection
+ write_connection.write = true;
+ if let Some(mut write_queue_constructor) = self.write_queue_constructor {
+ let write_channel = write_queue_constructor(write_connection);
+ queues.insert(self.connection.uri.clone(), write_channel);
+ } else {
+ use std::sync::mpsc::channel;
+
+ let (sender, reciever) = channel::<QueuedWrite>();
+ thread::spawn(move || {
+ while let Ok(write) = reciever.recv() {
+ write(&write_connection)
+ }
+ });
+
+ let sender = UnboundedSyncSender::new(sender);
+ queues.insert(
+ self.connection.uri.clone(),
+ Box::new(move |queued_write| {
+ sender
+ .send(queued_write)
+ .expect("Could not send write action to backgorund thread");
+ }),
+ );
+ }
+ }
+ }
+
let db_initialize_query = self.db_initialize_query;
self.connection
@@ -90,6 +137,7 @@ impl<M: Migrator> ThreadSafeConnection<M> {
pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
ThreadSafeConnectionBuilder::<M> {
db_initialize_query: None,
+ write_queue_constructor: None,
connection: Self {
uri: Arc::from(uri),
persistent,
@@ -112,48 +160,21 @@ impl<M: Migrator> ThreadSafeConnection<M> {
Connection::open_memory(Some(self.uri.as_ref()))
}
- fn queue_write_task(&self, callback: QueuedWrite) {
- // Startup write thread for this database if one hasn't already
- // been started and insert a channel to queue work for it
- if !QUEUES.read().contains_key(&self.uri) {
- let mut queues = QUEUES.write();
- if !queues.contains_key(&self.uri) {
- use std::sync::mpsc::channel;
-
- let (sender, reciever) = channel::<QueuedWrite>();
- let mut write_connection = self.create_connection();
- // Enable writes for this connection
- write_connection.write = true;
- thread::spawn(move || {
- while let Ok(write) = reciever.recv() {
- write(&write_connection)
- }
- });
-
- queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender));
- }
- }
-
- // Grab the queue for this database
- let queues = QUEUES.read();
- let write_channel = queues.get(&self.uri).unwrap();
-
- write_channel
- .send(callback)
- .expect("Could not send write action to backgorund thread");
- }
-
pub fn write<T: 'static + Send + Sync>(
&self,
callback: impl 'static + Send + FnOnce(&Connection) -> T,
) -> impl Future<Output = T> {
+ let queues = QUEUES.read();
+ let write_channel = queues
+ .get(&self.uri)
+ .expect("Queues are inserted when build is called. This should always succeed");
+
// Create a one shot channel for the result of the queued write
// so we can await on the result
let (sender, reciever) = oneshot::channel();
- self.queue_write_task(Box::new(move |connection| {
+ write_channel(Box::new(move |connection| {
sender.send(callback(connection)).ok();
}));
-
reciever.map(|response| response.expect("Background writer thread unexpectedly closed"))
}