thread_safe_connection.rs

  1use anyhow::Context;
  2use collections::HashMap;
  3use futures::{channel::oneshot, Future, FutureExt};
  4use lazy_static::lazy_static;
  5use parking_lot::{Mutex, RwLock};
  6use std::{marker::PhantomData, ops::Deref, sync::Arc, thread};
  7use thread_local::ThreadLocal;
  8
  9use crate::{connection::Connection, domain::Migrator, util::UnboundedSyncSender};
 10
 11const MIGRATION_RETRIES: usize = 10;
 12
 13type QueuedWrite = Box<dyn 'static + Send + FnOnce()>;
 14type WriteQueue = Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>;
 15type WriteQueueConstructor = Box<dyn 'static + Send + FnMut() -> WriteQueue>;
 16lazy_static! {
 17    /// List of queues of tasks by database uri. This lets us serialize writes to the database
 18    /// and have a single worker thread per db file. This means many thread safe connections
 19    /// (possibly with different migrations) could all be communicating with the same background
 20    /// thread.
 21    static ref QUEUES: RwLock<HashMap<Arc<str>, WriteQueue>> =
 22        Default::default();
 23}
 24
 25/// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static,
 26/// whatever. It derefs to a synchronous connection by thread that is read only. A write capable connection
 27/// may be accessed by passing a callback to the `write` function which will queue the callback
 28pub struct ThreadSafeConnection<M: Migrator + 'static = ()> {
 29    uri: Arc<str>,
 30    persistent: bool,
 31    connection_initialize_query: Option<&'static str>,
 32    connections: Arc<ThreadLocal<Connection>>,
 33    _migrator: PhantomData<*mut M>,
 34}
 35
 36unsafe impl<M: Migrator> Send for ThreadSafeConnection<M> {}
 37unsafe impl<M: Migrator> Sync for ThreadSafeConnection<M> {}
 38
 39pub struct ThreadSafeConnectionBuilder<M: Migrator + 'static = ()> {
 40    db_initialize_query: Option<&'static str>,
 41    write_queue_constructor: Option<WriteQueueConstructor>,
 42    connection: ThreadSafeConnection<M>,
 43}
 44
 45impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
 46    /// Sets the query to run every time a connection is opened. This must
 47    /// be infallible (EG only use pragma statements) and not cause writes.
 48    /// to the db or it will panic.
 49    pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
 50        self.connection.connection_initialize_query = Some(initialize_query);
 51        self
 52    }
 53
 54    /// Queues an initialization query for the database file. This must be infallible
 55    /// but may cause changes to the database file such as with `PRAGMA journal_mode`
 56    pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
 57        self.db_initialize_query = Some(initialize_query);
 58        self
 59    }
 60
 61    /// Specifies how the thread safe connection should serialize writes. If provided
 62    /// the connection will call the write_queue_constructor for each database file in
 63    /// this process. The constructor is responsible for setting up a background thread or
 64    /// async task which handles queued writes with the provided connection.
 65    pub fn with_write_queue_constructor(
 66        mut self,
 67        write_queue_constructor: WriteQueueConstructor,
 68    ) -> Self {
 69        self.write_queue_constructor = Some(write_queue_constructor);
 70        self
 71    }
 72
 73    pub async fn build(self) -> anyhow::Result<ThreadSafeConnection<M>> {
 74        self.connection
 75            .initialize_queues(self.write_queue_constructor);
 76
 77        let db_initialize_query = self.db_initialize_query;
 78
 79        self.connection
 80            .write(move |connection| {
 81                if let Some(db_initialize_query) = db_initialize_query {
 82                    connection.exec(db_initialize_query).with_context(|| {
 83                        format!(
 84                            "Db initialize query failed to execute: {}",
 85                            db_initialize_query
 86                        )
 87                    })?()?;
 88                }
 89
 90                // Retry failed migrations in case they were run in parallel from different
 91                // processes. This gives a best attempt at migrating before bailing
 92                let mut migration_result =
 93                    anyhow::Result::<()>::Err(anyhow::anyhow!("Migration never run"));
 94
 95                for _ in 0..MIGRATION_RETRIES {
 96                    migration_result = connection
 97                        .with_savepoint("thread_safe_multi_migration", || M::migrate(connection));
 98
 99                    if migration_result.is_ok() {
100                        break;
101                    }
102                }
103
104                migration_result
105            })
106            .await?;
107
108        Ok(self.connection)
109    }
110}
111
112impl<M: Migrator> ThreadSafeConnection<M> {
113    fn initialize_queues(&self, write_queue_constructor: Option<WriteQueueConstructor>) -> bool {
114        if !QUEUES.read().contains_key(&self.uri) {
115            let mut queues = QUEUES.write();
116            if !queues.contains_key(&self.uri) {
117                let mut write_queue_constructor =
118                    write_queue_constructor.unwrap_or_else(background_thread_queue);
119                queues.insert(self.uri.clone(), write_queue_constructor());
120                return true;
121            }
122        }
123        false
124    }
125
126    pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
127        ThreadSafeConnectionBuilder::<M> {
128            db_initialize_query: None,
129            write_queue_constructor: None,
130            connection: Self {
131                uri: Arc::from(uri),
132                persistent,
133                connection_initialize_query: None,
134                connections: Default::default(),
135                _migrator: PhantomData,
136            },
137        }
138    }
139
140    /// Opens a new db connection with the initialized file path. This is internal and only
141    /// called from the deref function.
142    fn open_file(uri: &str) -> Connection {
143        Connection::open_file(uri)
144    }
145
146    /// Opens a shared memory connection using the file path as the identifier. This is internal
147    /// and only called from the deref function.
148    fn open_shared_memory(uri: &str) -> Connection {
149        Connection::open_memory(Some(uri))
150    }
151
152    pub fn write<T: 'static + Send + Sync>(
153        &self,
154        callback: impl 'static + Send + FnOnce(&Connection) -> T,
155    ) -> impl Future<Output = T> {
156        // Check and invalidate queue and maybe recreate queue
157        let queues = QUEUES.read();
158        let write_channel = queues
159            .get(&self.uri)
160            .expect("Queues are inserted when build is called. This should always succeed");
161
162        // Create a one shot channel for the result of the queued write
163        // so we can await on the result
164        let (sender, receiver) = oneshot::channel();
165
166        let thread_safe_connection = (*self).clone();
167        write_channel(Box::new(move || {
168            let connection = thread_safe_connection.deref();
169            let result = connection.with_write(|connection| callback(connection));
170            sender.send(result).ok();
171        }));
172        receiver.map(|response| response.expect("Write queue unexpectedly closed"))
173    }
174
175    pub(crate) fn create_connection(
176        persistent: bool,
177        uri: &str,
178        connection_initialize_query: Option<&'static str>,
179    ) -> Connection {
180        let mut connection = if persistent {
181            Self::open_file(uri)
182        } else {
183            Self::open_shared_memory(uri)
184        };
185
186        // Disallow writes on the connection. The only writes allowed for thread safe connections
187        // are from the background thread that can serialize them.
188        *connection.write.get_mut() = false;
189
190        if let Some(initialize_query) = connection_initialize_query {
191            connection.exec(initialize_query).unwrap_or_else(|_| {
192                panic!("Initialize query failed to execute: {}", initialize_query)
193            })()
194            .unwrap()
195        }
196
197        connection
198    }
199}
200
201impl ThreadSafeConnection<()> {
202    /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
203    /// This allows construction to be infallible and not write to the db.
204    pub fn new(
205        uri: &str,
206        persistent: bool,
207        connection_initialize_query: Option<&'static str>,
208        write_queue_constructor: Option<WriteQueueConstructor>,
209    ) -> Self {
210        let connection = Self {
211            uri: Arc::from(uri),
212            persistent,
213            connection_initialize_query,
214            connections: Default::default(),
215            _migrator: PhantomData,
216        };
217
218        connection.initialize_queues(write_queue_constructor);
219        connection
220    }
221}
222
223impl<M: Migrator> Clone for ThreadSafeConnection<M> {
224    fn clone(&self) -> Self {
225        Self {
226            uri: self.uri.clone(),
227            persistent: self.persistent,
228            connection_initialize_query: self.connection_initialize_query,
229            connections: self.connections.clone(),
230            _migrator: PhantomData,
231        }
232    }
233}
234
235impl<M: Migrator> Deref for ThreadSafeConnection<M> {
236    type Target = Connection;
237
238    fn deref(&self) -> &Self::Target {
239        self.connections.get_or(|| {
240            Self::create_connection(self.persistent, &self.uri, self.connection_initialize_query)
241        })
242    }
243}
244
245pub fn background_thread_queue() -> WriteQueueConstructor {
246    use std::sync::mpsc::channel;
247
248    Box::new(|| {
249        let (sender, receiver) = channel::<QueuedWrite>();
250
251        thread::spawn(move || {
252            while let Ok(write) = receiver.recv() {
253                write()
254            }
255        });
256
257        let sender = UnboundedSyncSender::new(sender);
258        Box::new(move |queued_write| {
259            sender
260                .send(queued_write)
261                .expect("Could not send write action to background thread");
262        })
263    })
264}
265
266pub fn locking_queue() -> WriteQueueConstructor {
267    Box::new(|| {
268        let write_mutex = Mutex::new(());
269        Box::new(move |queued_write| {
270            let _lock = write_mutex.lock();
271            queued_write();
272        })
273    })
274}
275
#[cfg(test)]
mod test {
    use indoc::indoc;
    use std::ops::Deref;

    use std::thread;

    use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};

    /// Builds 100 connections to the same shared in-memory database concurrently. Each one runs
    /// the db initialization query and migrations through the shared write queue. Any
    /// panic in a worker thread fails the test.
    #[test]
    fn many_initialize_and_migrate_queries_at_once() {
        let mut handles = vec![];

        enum TestDomain {}
        impl Domain for TestDomain {
            fn name() -> &'static str {
                "test"
            }
            fn migrations() -> &'static [&'static str] {
                &["CREATE TABLE test(col1 TEXT, col2 TEXT) STRICT;"]
            }
        }

        for _ in 0..100 {
            handles.push(thread::spawn(|| {
                let builder =
                    ThreadSafeConnection::<TestDomain>::builder("annoying-test.db", false)
                        .with_db_initialization_query("PRAGMA journal_mode=WAL")
                        .with_connection_initialize_query(indoc! {"
                                PRAGMA synchronous=NORMAL;
                                PRAGMA busy_timeout=1;
                                PRAGMA foreign_keys=TRUE;
                                PRAGMA case_sensitive_like=TRUE;
                            "});

                let _ = smol::block_on(builder.build()).unwrap().deref();
            }));
        }

        // Propagate worker panics; discarding the join result would let this test pass even
        // when every `build` failed.
        for handle in handles {
            handle
                .join()
                .expect("a thread building a ThreadSafeConnection panicked");
        }
    }

    /// The `workspaces` migration declares `FOREIGN KEY(active_pane)` although the table has no
    /// `active_pane` column. The migration is therefore expected to fail, and `build(...).unwrap()`
    /// to panic.
    #[test]
    #[should_panic]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}
        impl Domain for TestWorkspace {
            fn name() -> &'static str {
                "workspace"
            }

            fn migrations() -> &'static [&'static str] {
                &["
                    CREATE TABLE workspaces(
                        workspace_id INTEGER PRIMARY KEY,
                        dock_visible INTEGER, -- Boolean
                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                    ) STRICT;

                    CREATE TABLE panes(
                        pane_id INTEGER PRIMARY KEY,
                        workspace_id INTEGER NOT NULL,
                        active INTEGER NOT NULL, -- Boolean
                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
                            ON DELETE CASCADE
                            ON UPDATE CASCADE
                    ) STRICT;
                "]
            }
        }

        let builder =
            ThreadSafeConnection::<TestWorkspace>::builder("wild_zed_lost_failure", false)
                .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true");

        smol::block_on(builder.build()).unwrap();
    }
}