thread_safe_connection.rs

  1use anyhow::Context;
  2use futures::{channel::oneshot, Future, FutureExt};
  3use lazy_static::lazy_static;
  4use parking_lot::{Mutex, RwLock};
  5use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
  6use thread_local::ThreadLocal;
  7
  8use crate::{connection::Connection, domain::Migrator, util::UnboundedSyncSender};
  9
 10const MIGRATION_RETRIES: usize = 10;
 11
 12type QueuedWrite = Box<dyn 'static + Send + FnOnce()>;
 13type WriteQueue = Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>;
 14type WriteQueueConstructor = Box<dyn 'static + Send + FnMut() -> WriteQueue>;
 15lazy_static! {
 16    /// List of queues of tasks by database uri. This lets us serialize writes to the database
 17    /// and have a single worker thread per db file. This means many thread safe connections
 18    /// (possibly with different migrations) could all be communicating with the same background
 19    /// thread.
 20    static ref QUEUES: RwLock<HashMap<Arc<str>, WriteQueue>> =
 21        Default::default();
 22}
 23
/// Thread safe connection to a given database file or in memory db. This can be cloned, shared,
/// stored in a static, etc. It derefs to a [`Connection`] that is lazily opened once per thread
/// and has writes disabled. Writes are performed by passing a callback to
/// [`ThreadSafeConnection::write`], which queues the callback on the write queue registered for
/// this database's uri.
pub struct ThreadSafeConnection<M: Migrator + 'static = ()> {
    /// Database file path (when `persistent`) or the name of a shared in-memory database.
    uri: Arc<str>,
    /// `true` opens `uri` as a file, `false` opens it as a shared in-memory database.
    persistent: bool,
    /// Query run on every newly opened per-thread connection.
    connection_initialize_query: Option<&'static str>,
    /// One lazily created connection per thread; shared between clones of this handle.
    connections: Arc<ThreadLocal<Connection>>,
    /// Ties the migrator type to this handle without storing a value of it.
    _migrator: PhantomData<*mut M>,
}

// SAFETY: NOTE(review): `PhantomData<*mut M>` opts this type out of the auto `Send`/`Sync`
// impls and these two impls opt back in. `M` is never stored (it is only used to call
// `M::migrate`), so the marker carries no data. The remaining fields are `Arc<str>`, `bool`,
// `Option<&'static str>` and `Arc<ThreadLocal<Connection>>`; soundness presumably relies on
// `ThreadLocal` only ever handing a thread its own `Connection` — confirm against
// `Connection`'s Send/Sync guarantees.
unsafe impl<M: Migrator> Send for ThreadSafeConnection<M> {}
unsafe impl<M: Migrator> Sync for ThreadSafeConnection<M> {}
 37
/// Builder for a [`ThreadSafeConnection`] that can additionally run a database initialization
/// query and the `M` migrations. Created with [`ThreadSafeConnection::builder`] and finished
/// with [`ThreadSafeConnectionBuilder::build`].
pub struct ThreadSafeConnectionBuilder<M: Migrator + 'static = ()> {
    /// Query executed through the write queue when `build` runs.
    db_initialize_query: Option<&'static str>,
    /// Write queue constructor; only used if no queue is registered for the uri yet.
    write_queue_constructor: Option<WriteQueueConstructor>,
    /// The connection being configured; returned by `build` on success.
    connection: ThreadSafeConnection<M>,
}
 43
 44impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
 45    /// Sets the query to run every time a connection is opened. This must
 46    /// be infallible (EG only use pragma statements) and not cause writes.
 47    /// to the db or it will panic.
 48    pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
 49        self.connection.connection_initialize_query = Some(initialize_query);
 50        self
 51    }
 52
 53    /// Queues an initialization query for the database file. This must be infallible
 54    /// but may cause changes to the database file such as with `PRAGMA journal_mode`
 55    pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
 56        self.db_initialize_query = Some(initialize_query);
 57        self
 58    }
 59
 60    /// Specifies how the thread safe connection should serialize writes. If provided
 61    /// the connection will call the write_queue_constructor for each database file in
 62    /// this process. The constructor is responsible for setting up a background thread or
 63    /// async task which handles queued writes with the provided connection.
 64    pub fn with_write_queue_constructor(
 65        mut self,
 66        write_queue_constructor: WriteQueueConstructor,
 67    ) -> Self {
 68        self.write_queue_constructor = Some(write_queue_constructor);
 69        self
 70    }
 71
 72    pub async fn build(self) -> anyhow::Result<ThreadSafeConnection<M>> {
 73        self.connection
 74            .initialize_queues(self.write_queue_constructor);
 75
 76        let db_initialize_query = self.db_initialize_query;
 77
 78        self.connection
 79            .write(move |connection| {
 80                if let Some(db_initialize_query) = db_initialize_query {
 81                    connection.exec(db_initialize_query).with_context(|| {
 82                        format!(
 83                            "Db initialize query failed to execute: {}",
 84                            db_initialize_query
 85                        )
 86                    })?()?;
 87                }
 88
 89                // Retry failed migrations in case they were run in parallel from different
 90                // processes. This gives a best attempt at migrating before bailing
 91                let mut migration_result =
 92                    anyhow::Result::<()>::Err(anyhow::anyhow!("Migration never run"));
 93
 94                for _ in 0..MIGRATION_RETRIES {
 95                    migration_result = connection
 96                        .with_savepoint("thread_safe_multi_migration", || M::migrate(connection));
 97
 98                    if migration_result.is_ok() {
 99                        break;
100                    }
101                }
102
103                migration_result
104            })
105            .await?;
106
107        Ok(self.connection)
108    }
109}
110
111impl<M: Migrator> ThreadSafeConnection<M> {
112    fn initialize_queues(&self, write_queue_constructor: Option<WriteQueueConstructor>) -> bool {
113        if !QUEUES.read().contains_key(&self.uri) {
114            let mut queues = QUEUES.write();
115            if !queues.contains_key(&self.uri) {
116                let mut write_queue_constructor =
117                    write_queue_constructor.unwrap_or_else(background_thread_queue);
118                queues.insert(self.uri.clone(), write_queue_constructor());
119                return true;
120            }
121        }
122        false
123    }
124
125    pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
126        ThreadSafeConnectionBuilder::<M> {
127            db_initialize_query: None,
128            write_queue_constructor: None,
129            connection: Self {
130                uri: Arc::from(uri),
131                persistent,
132                connection_initialize_query: None,
133                connections: Default::default(),
134                _migrator: PhantomData,
135            },
136        }
137    }
138
139    /// Opens a new db connection with the initialized file path. This is internal and only
140    /// called from the deref function.
141    fn open_file(uri: &str) -> Connection {
142        Connection::open_file(uri)
143    }
144
145    /// Opens a shared memory connection using the file path as the identifier. This is internal
146    /// and only called from the deref function.
147    fn open_shared_memory(uri: &str) -> Connection {
148        Connection::open_memory(Some(uri))
149    }
150
151    pub fn write<T: 'static + Send + Sync>(
152        &self,
153        callback: impl 'static + Send + FnOnce(&Connection) -> T,
154    ) -> impl Future<Output = T> {
155        // Check and invalidate queue and maybe recreate queue
156        let queues = QUEUES.read();
157        let write_channel = queues
158            .get(&self.uri)
159            .expect("Queues are inserted when build is called. This should always succeed");
160
161        // Create a one shot channel for the result of the queued write
162        // so we can await on the result
163        let (sender, receiver) = oneshot::channel();
164
165        let thread_safe_connection = (*self).clone();
166        write_channel(Box::new(move || {
167            let connection = thread_safe_connection.deref();
168            let result = connection.with_write(|connection| callback(connection));
169            sender.send(result).ok();
170        }));
171        receiver.map(|response| response.expect("Write queue unexpectedly closed"))
172    }
173
174    pub(crate) fn create_connection(
175        persistent: bool,
176        uri: &str,
177        connection_initialize_query: Option<&'static str>,
178    ) -> Connection {
179        let mut connection = if persistent {
180            Self::open_file(uri)
181        } else {
182            Self::open_shared_memory(uri)
183        };
184
185        // Disallow writes on the connection. The only writes allowed for thread safe connections
186        // are from the background thread that can serialize them.
187        *connection.write.get_mut() = false;
188
189        if let Some(initialize_query) = connection_initialize_query {
190            connection.exec(initialize_query).unwrap_or_else(|_| {
191                panic!("Initialize query failed to execute: {}", initialize_query)
192            })()
193            .unwrap()
194        }
195
196        connection
197    }
198}
199
200impl ThreadSafeConnection<()> {
201    /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
202    /// This allows construction to be infallible and not write to the db.
203    pub fn new(
204        uri: &str,
205        persistent: bool,
206        connection_initialize_query: Option<&'static str>,
207        write_queue_constructor: Option<WriteQueueConstructor>,
208    ) -> Self {
209        let connection = Self {
210            uri: Arc::from(uri),
211            persistent,
212            connection_initialize_query,
213            connections: Default::default(),
214            _migrator: PhantomData,
215        };
216
217        connection.initialize_queues(write_queue_constructor);
218        connection
219    }
220}
221
222impl<M: Migrator> Clone for ThreadSafeConnection<M> {
223    fn clone(&self) -> Self {
224        Self {
225            uri: self.uri.clone(),
226            persistent: self.persistent,
227            connection_initialize_query: self.connection_initialize_query,
228            connections: self.connections.clone(),
229            _migrator: PhantomData,
230        }
231    }
232}
233
234impl<M: Migrator> Deref for ThreadSafeConnection<M> {
235    type Target = Connection;
236
237    fn deref(&self) -> &Self::Target {
238        self.connections.get_or(|| {
239            Self::create_connection(self.persistent, &self.uri, self.connection_initialize_query)
240        })
241    }
242}
243
244pub fn background_thread_queue() -> WriteQueueConstructor {
245    use std::sync::mpsc::channel;
246
247    Box::new(|| {
248        let (sender, receiver) = channel::<QueuedWrite>();
249
250        thread::spawn(move || {
251            while let Ok(write) = receiver.recv() {
252                write()
253            }
254        });
255
256        let sender = UnboundedSyncSender::new(sender);
257        Box::new(move |queued_write| {
258            sender
259                .send(queued_write)
260                .expect("Could not send write action to background thread");
261        })
262    })
263}
264
265pub fn locking_queue() -> WriteQueueConstructor {
266    Box::new(|| {
267        let write_mutex = Mutex::new(());
268        Box::new(move |queued_write| {
269            let _lock = write_mutex.lock();
270            queued_write();
271        })
272    })
273}
274
#[cfg(test)]
mod test {
    use indoc::indoc;

    use std::{ops::Deref, thread};

    use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};

    /// Many threads concurrently build connections to the same database. All of them share one
    /// write queue, so their initialization and migration run one after another and must all
    /// succeed.
    #[test]
    fn many_initialize_and_migrate_queries_at_once() {
        let mut handles = vec![];

        enum TestDomain {}
        impl Domain for TestDomain {
            fn name() -> &'static str {
                "test"
            }
            fn migrations() -> &'static [&'static str] {
                &["CREATE TABLE test(col1 TEXT, col2 TEXT) STRICT;"]
            }
        }

        for _ in 0..100 {
            handles.push(thread::spawn(|| {
                let builder =
                    ThreadSafeConnection::<TestDomain>::builder("annoying-test.db", false)
                        .with_db_initialization_query("PRAGMA journal_mode=WAL")
                        .with_connection_initialize_query(indoc! {"
                                PRAGMA synchronous=NORMAL;
                                PRAGMA busy_timeout=1;
                                PRAGMA foreign_keys=TRUE;
                                PRAGMA case_sensitive_like=TRUE;
                            "});

                let _ = smol::block_on(builder.build()).unwrap().deref();
            }));
        }

        for handle in handles {
            // Propagate panics from the worker threads; discarding the join result would let a
            // failed build in any thread go unnoticed.
            handle.join().unwrap();
        }
    }

    /// The migration declares a foreign key on `active_pane`, which is not a column of
    /// `workspaces`, so the migration is expected to fail and `build(..).unwrap()` to panic.
    #[test]
    #[should_panic]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}
        impl Domain for TestWorkspace {
            fn name() -> &'static str {
                "workspace"
            }

            fn migrations() -> &'static [&'static str] {
                &["
                    CREATE TABLE workspaces(
                        workspace_id INTEGER PRIMARY KEY,
                        dock_visible INTEGER, -- Boolean
                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                    ) STRICT;

                    CREATE TABLE panes(
                        pane_id INTEGER PRIMARY KEY,
                        workspace_id INTEGER NOT NULL,
                        active INTEGER NOT NULL, -- Boolean
                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
                            ON DELETE CASCADE
                            ON UPDATE CASCADE
                    ) STRICT;
                "]
            }
        }

        let builder =
            ThreadSafeConnection::<TestWorkspace>::builder("wild_zed_lost_failure", false)
                .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true");

        smol::block_on(builder.build()).unwrap();
    }
}