thread_safe_connection.rs

  1use anyhow::Context;
  2use collections::HashMap;
  3use futures::{channel::oneshot, Future, FutureExt};
  4use parking_lot::{Mutex, RwLock};
  5use std::{
  6    marker::PhantomData,
  7    ops::Deref,
  8    sync::{Arc, LazyLock},
  9    thread,
 10};
 11use thread_local::ThreadLocal;
 12
 13use crate::{connection::Connection, domain::Migrator, util::UnboundedSyncSender};
 14
 15const MIGRATION_RETRIES: usize = 10;
 16
 17type QueuedWrite = Box<dyn 'static + Send + FnOnce()>;
 18type WriteQueue = Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>;
 19type WriteQueueConstructor = Box<dyn 'static + Send + FnMut() -> WriteQueue>;
 20
 21/// List of queues of tasks by database uri. This lets us serialize writes to the database
 22/// and have a single worker thread per db file. This means many thread safe connections
 23/// (possibly with different migrations) could all be communicating with the same background
 24/// thread.
 25static QUEUES: LazyLock<RwLock<HashMap<Arc<str>, WriteQueue>>> =
 26    LazyLock::new(|| Default::default());
 27
 28/// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static,
 29/// whatever. It derefs to a synchronous connection by thread that is read only. A write capable connection
 30/// may be accessed by passing a callback to the `write` function which will queue the callback
 31pub struct ThreadSafeConnection<M: Migrator + 'static = ()> {
 32    uri: Arc<str>,
 33    persistent: bool,
 34    connection_initialize_query: Option<&'static str>,
 35    connections: Arc<ThreadLocal<Connection>>,
 36    _migrator: PhantomData<*mut M>,
 37}
 38
 39unsafe impl<M: Migrator> Send for ThreadSafeConnection<M> {}
 40unsafe impl<M: Migrator> Sync for ThreadSafeConnection<M> {}
 41
 42pub struct ThreadSafeConnectionBuilder<M: Migrator + 'static = ()> {
 43    db_initialize_query: Option<&'static str>,
 44    write_queue_constructor: Option<WriteQueueConstructor>,
 45    connection: ThreadSafeConnection<M>,
 46}
 47
 48impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
 49    /// Sets the query to run every time a connection is opened. This must
 50    /// be infallible (EG only use pragma statements) and not cause writes.
 51    /// to the db or it will panic.
 52    pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
 53        self.connection.connection_initialize_query = Some(initialize_query);
 54        self
 55    }
 56
 57    /// Queues an initialization query for the database file. This must be infallible
 58    /// but may cause changes to the database file such as with `PRAGMA journal_mode`
 59    pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
 60        self.db_initialize_query = Some(initialize_query);
 61        self
 62    }
 63
 64    /// Specifies how the thread safe connection should serialize writes. If provided
 65    /// the connection will call the write_queue_constructor for each database file in
 66    /// this process. The constructor is responsible for setting up a background thread or
 67    /// async task which handles queued writes with the provided connection.
 68    pub fn with_write_queue_constructor(
 69        mut self,
 70        write_queue_constructor: WriteQueueConstructor,
 71    ) -> Self {
 72        self.write_queue_constructor = Some(write_queue_constructor);
 73        self
 74    }
 75
 76    pub async fn build(self) -> anyhow::Result<ThreadSafeConnection<M>> {
 77        self.connection
 78            .initialize_queues(self.write_queue_constructor);
 79
 80        let db_initialize_query = self.db_initialize_query;
 81
 82        self.connection
 83            .write(move |connection| {
 84                if let Some(db_initialize_query) = db_initialize_query {
 85                    connection.exec(db_initialize_query).with_context(|| {
 86                        format!(
 87                            "Db initialize query failed to execute: {}",
 88                            db_initialize_query
 89                        )
 90                    })?()?;
 91                }
 92
 93                // Retry failed migrations in case they were run in parallel from different
 94                // processes. This gives a best attempt at migrating before bailing
 95                let mut migration_result =
 96                    anyhow::Result::<()>::Err(anyhow::anyhow!("Migration never run"));
 97
 98                for _ in 0..MIGRATION_RETRIES {
 99                    migration_result = connection
100                        .with_savepoint("thread_safe_multi_migration", || M::migrate(connection));
101
102                    if migration_result.is_ok() {
103                        break;
104                    }
105                }
106
107                migration_result
108            })
109            .await?;
110
111        Ok(self.connection)
112    }
113}
114
115impl<M: Migrator> ThreadSafeConnection<M> {
116    fn initialize_queues(&self, write_queue_constructor: Option<WriteQueueConstructor>) -> bool {
117        if !QUEUES.read().contains_key(&self.uri) {
118            let mut queues = QUEUES.write();
119            if !queues.contains_key(&self.uri) {
120                let mut write_queue_constructor =
121                    write_queue_constructor.unwrap_or_else(background_thread_queue);
122                queues.insert(self.uri.clone(), write_queue_constructor());
123                return true;
124            }
125        }
126        false
127    }
128
129    pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
130        ThreadSafeConnectionBuilder::<M> {
131            db_initialize_query: None,
132            write_queue_constructor: None,
133            connection: Self {
134                uri: Arc::from(uri),
135                persistent,
136                connection_initialize_query: None,
137                connections: Default::default(),
138                _migrator: PhantomData,
139            },
140        }
141    }
142
143    /// Opens a new db connection with the initialized file path. This is internal and only
144    /// called from the deref function.
145    fn open_file(uri: &str) -> Connection {
146        Connection::open_file(uri)
147    }
148
149    /// Opens a shared memory connection using the file path as the identifier. This is internal
150    /// and only called from the deref function.
151    fn open_shared_memory(uri: &str) -> Connection {
152        Connection::open_memory(Some(uri))
153    }
154
155    pub fn write<T: 'static + Send + Sync>(
156        &self,
157        callback: impl 'static + Send + FnOnce(&Connection) -> T,
158    ) -> impl Future<Output = T> {
159        // Check and invalidate queue and maybe recreate queue
160        let queues = QUEUES.read();
161        let write_channel = queues
162            .get(&self.uri)
163            .expect("Queues are inserted when build is called. This should always succeed");
164
165        // Create a one shot channel for the result of the queued write
166        // so we can await on the result
167        let (sender, receiver) = oneshot::channel();
168
169        let thread_safe_connection = (*self).clone();
170        write_channel(Box::new(move || {
171            let connection = thread_safe_connection.deref();
172            let result = connection.with_write(|connection| callback(connection));
173            sender.send(result).ok();
174        }));
175        receiver.map(|response| response.expect("Write queue unexpectedly closed"))
176    }
177
178    pub(crate) fn create_connection(
179        persistent: bool,
180        uri: &str,
181        connection_initialize_query: Option<&'static str>,
182    ) -> Connection {
183        let mut connection = if persistent {
184            Self::open_file(uri)
185        } else {
186            Self::open_shared_memory(uri)
187        };
188
189        // Disallow writes on the connection. The only writes allowed for thread safe connections
190        // are from the background thread that can serialize them.
191        *connection.write.get_mut() = false;
192
193        if let Some(initialize_query) = connection_initialize_query {
194            connection.exec(initialize_query).unwrap_or_else(|_| {
195                panic!("Initialize query failed to execute: {}", initialize_query)
196            })()
197            .unwrap()
198        }
199
200        connection
201    }
202}
203
204impl ThreadSafeConnection<()> {
205    /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
206    /// This allows construction to be infallible and not write to the db.
207    pub fn new(
208        uri: &str,
209        persistent: bool,
210        connection_initialize_query: Option<&'static str>,
211        write_queue_constructor: Option<WriteQueueConstructor>,
212    ) -> Self {
213        let connection = Self {
214            uri: Arc::from(uri),
215            persistent,
216            connection_initialize_query,
217            connections: Default::default(),
218            _migrator: PhantomData,
219        };
220
221        connection.initialize_queues(write_queue_constructor);
222        connection
223    }
224}
225
226impl<M: Migrator> Clone for ThreadSafeConnection<M> {
227    fn clone(&self) -> Self {
228        Self {
229            uri: self.uri.clone(),
230            persistent: self.persistent,
231            connection_initialize_query: self.connection_initialize_query,
232            connections: self.connections.clone(),
233            _migrator: PhantomData,
234        }
235    }
236}
237
238impl<M: Migrator> Deref for ThreadSafeConnection<M> {
239    type Target = Connection;
240
241    fn deref(&self) -> &Self::Target {
242        self.connections.get_or(|| {
243            Self::create_connection(self.persistent, &self.uri, self.connection_initialize_query)
244        })
245    }
246}
247
248pub fn background_thread_queue() -> WriteQueueConstructor {
249    use std::sync::mpsc::channel;
250
251    Box::new(|| {
252        let (sender, receiver) = channel::<QueuedWrite>();
253
254        thread::spawn(move || {
255            while let Ok(write) = receiver.recv() {
256                write()
257            }
258        });
259
260        let sender = UnboundedSyncSender::new(sender);
261        Box::new(move |queued_write| {
262            sender
263                .send(queued_write)
264                .expect("Could not send write action to background thread");
265        })
266    })
267}
268
269pub fn locking_queue() -> WriteQueueConstructor {
270    Box::new(|| {
271        let write_mutex = Mutex::new(());
272        Box::new(move |queued_write| {
273            let _lock = write_mutex.lock();
274            queued_write();
275        })
276    })
277}
278
#[cfg(test)]
mod test {
    use std::{ops::Deref, thread};

    use indoc::indoc;

    use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};

    /// Builds 100 connections to the same shared in-memory database from 100 threads at once.
    /// Every build runs the db initialization query and the migrations through the single
    /// write queue registered for that uri.
    #[test]
    fn many_initialize_and_migrate_queries_at_once() {
        enum TestDomain {}

        impl Domain for TestDomain {
            fn name() -> &'static str {
                "test"
            }

            fn migrations() -> &'static [&'static str] {
                &["CREATE TABLE test(col1 TEXT, col2 TEXT) STRICT;"]
            }
        }

        let workers: Vec<_> = (0..100)
            .map(|_| {
                thread::spawn(|| {
                    let builder =
                        ThreadSafeConnection::<TestDomain>::builder("annoying-test.db", false)
                            .with_db_initialization_query("PRAGMA journal_mode=WAL")
                            .with_connection_initialize_query(indoc! {"
                                PRAGMA synchronous=NORMAL;
                                PRAGMA busy_timeout=1;
                                PRAGMA foreign_keys=TRUE;
                                PRAGMA case_sensitive_like=TRUE;
                            "});

                    let connection = smol::block_on(builder.build()).unwrap();
                    // Make this worker thread open its own connection as well.
                    let _ = connection.deref();
                })
            })
            .collect();

        // NOTE(review): join results are discarded, so a worker that panics (for example on
        // the `unwrap` above) does not fail this test. Confirm whether that is intended.
        for worker in workers {
            let _ = worker.join();
        }
    }

    /// The `workspaces` migration declares `FOREIGN KEY(active_pane)` without defining an
    /// `active_pane` column, so the migrations are expected to fail. `build` should then
    /// return an error, and unwrapping it panics.
    #[test]
    #[should_panic]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}

        impl Domain for TestWorkspace {
            fn name() -> &'static str {
                "workspace"
            }

            fn migrations() -> &'static [&'static str] {
                &["
                    CREATE TABLE workspaces(
                        workspace_id INTEGER PRIMARY KEY,
                        dock_visible INTEGER, -- Boolean
                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                    ) STRICT;

                    CREATE TABLE panes(
                        pane_id INTEGER PRIMARY KEY,
                        workspace_id INTEGER NOT NULL,
                        active INTEGER NOT NULL, -- Boolean
                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
                            ON DELETE CASCADE
                            ON UPDATE CASCADE
                    ) STRICT;
                "]
            }
        }

        let result = smol::block_on(
            ThreadSafeConnection::<TestWorkspace>::builder("wild_zed_lost_failure", false)
                .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true")
                .build(),
        );
        result.unwrap();
    }
}