thread_safe_connection.rs

  1use futures::{channel::oneshot, Future, FutureExt};
  2use lazy_static::lazy_static;
  3use parking_lot::RwLock;
  4use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
  5use thread_local::ThreadLocal;
  6
  7use crate::{
  8    connection::Connection,
  9    domain::{Domain, Migrator},
 10    util::UnboundedSyncSender,
 11};
 12
 13const MIGRATION_RETRIES: usize = 10;
 14
 15type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
 16lazy_static! {
 17    /// List of queues of tasks by database uri. This lets us serialize writes to the database
 18    /// and have a single worker thread per db file. This means many thread safe connections
 19    /// (possibly with different migrations) could all be communicating with the same background
 20    /// thread.
 21    static ref QUEUES: RwLock<HashMap<Arc<str>, UnboundedSyncSender<QueuedWrite>>> =
 22        Default::default();
 23}
 24
 25/// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static,
 26/// whatever. It derefs to a synchronous connection by thread that is read only. A write capable connection
 27/// may be accessed by passing a callback to the `write` function which will queue the callback
 28pub struct ThreadSafeConnection<M: Migrator = ()> {
 29    uri: Arc<str>,
 30    persistent: bool,
 31    connection_initialize_query: Option<&'static str>,
 32    connections: Arc<ThreadLocal<Connection>>,
 33    _migrator: PhantomData<M>,
 34}
 35
 36unsafe impl<T: Migrator> Send for ThreadSafeConnection<T> {}
 37unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
 38
 39pub struct ThreadSafeConnectionBuilder<M: Migrator = ()> {
 40    db_initialize_query: Option<&'static str>,
 41    connection: ThreadSafeConnection<M>,
 42}
 43
 44impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
 45    /// Sets the query to run every time a connection is opened. This must
 46    /// be infallible (EG only use pragma statements) and not cause writes.
 47    /// to the db or it will panic.
 48    pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
 49        self.connection.connection_initialize_query = Some(initialize_query);
 50        self
 51    }
 52
 53    /// Queues an initialization query for the database file. This must be infallible
 54    /// but may cause changes to the database file such as with `PRAGMA journal_mode`
 55    pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
 56        self.db_initialize_query = Some(initialize_query);
 57        self
 58    }
 59
 60    pub async fn build(self) -> ThreadSafeConnection<M> {
 61        let db_initialize_query = self.db_initialize_query;
 62
 63        self.connection
 64            .write(move |connection| {
 65                if let Some(db_initialize_query) = db_initialize_query {
 66                    connection.exec(db_initialize_query).expect(&format!(
 67                        "Db initialize query failed to execute: {}",
 68                        db_initialize_query
 69                    ))()
 70                    .unwrap();
 71                }
 72
 73                let mut failure_result = None;
 74                for _ in 0..MIGRATION_RETRIES {
 75                    failure_result = Some(M::migrate(connection));
 76                    if failure_result.as_ref().unwrap().is_ok() {
 77                        break;
 78                    }
 79                }
 80
 81                failure_result.unwrap().expect("Migration failed");
 82            })
 83            .await;
 84
 85        self.connection
 86    }
 87}
 88
 89impl<M: Migrator> ThreadSafeConnection<M> {
 90    pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
 91        ThreadSafeConnectionBuilder::<M> {
 92            db_initialize_query: None,
 93            connection: Self {
 94                uri: Arc::from(uri),
 95                persistent,
 96                connection_initialize_query: None,
 97                connections: Default::default(),
 98                _migrator: PhantomData,
 99            },
100        }
101    }
102
103    /// Opens a new db connection with the initialized file path. This is internal and only
104    /// called from the deref function.
105    fn open_file(&self) -> Connection {
106        Connection::open_file(self.uri.as_ref())
107    }
108
109    /// Opens a shared memory connection using the file path as the identifier. This is internal
110    /// and only called from the deref function.
111    fn open_shared_memory(&self) -> Connection {
112        Connection::open_memory(Some(self.uri.as_ref()))
113    }
114
115    fn queue_write_task(&self, callback: QueuedWrite) {
116        // Startup write thread for this database if one hasn't already
117        // been started and insert a channel to queue work for it
118        if !QUEUES.read().contains_key(&self.uri) {
119            let mut queues = QUEUES.write();
120            if !queues.contains_key(&self.uri) {
121                use std::sync::mpsc::channel;
122
123                let (sender, reciever) = channel::<QueuedWrite>();
124                let mut write_connection = self.create_connection();
125                // Enable writes for this connection
126                write_connection.write = true;
127                thread::spawn(move || {
128                    while let Ok(write) = reciever.recv() {
129                        write(&write_connection)
130                    }
131                });
132
133                queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender));
134            }
135        }
136
137        // Grab the queue for this database
138        let queues = QUEUES.read();
139        let write_channel = queues.get(&self.uri).unwrap();
140
141        write_channel
142            .send(callback)
143            .expect("Could not send write action to backgorund thread");
144    }
145
146    pub fn write<T: 'static + Send + Sync>(
147        &self,
148        callback: impl 'static + Send + FnOnce(&Connection) -> T,
149    ) -> impl Future<Output = T> {
150        // Create a one shot channel for the result of the queued write
151        // so we can await on the result
152        let (sender, reciever) = oneshot::channel();
153        self.queue_write_task(Box::new(move |connection| {
154            sender.send(callback(connection)).ok();
155        }));
156
157        reciever.map(|response| response.expect("Background writer thread unexpectedly closed"))
158    }
159
160    pub(crate) fn create_connection(&self) -> Connection {
161        let mut connection = if self.persistent {
162            self.open_file()
163        } else {
164            self.open_shared_memory()
165        };
166
167        // Disallow writes on the connection. The only writes allowed for thread safe connections
168        // are from the background thread that can serialize them.
169        connection.write = false;
170
171        if let Some(initialize_query) = self.connection_initialize_query {
172            connection.exec(initialize_query).expect(&format!(
173                "Initialize query failed to execute: {}",
174                initialize_query
175            ))()
176            .unwrap()
177        }
178
179        connection
180    }
181}
182
183impl ThreadSafeConnection<()> {
184    /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
185    /// This allows construction to be infallible and not write to the db.
186    pub fn new(
187        uri: &str,
188        persistent: bool,
189        connection_initialize_query: Option<&'static str>,
190    ) -> Self {
191        Self {
192            uri: Arc::from(uri),
193            persistent,
194            connection_initialize_query,
195            connections: Default::default(),
196            _migrator: PhantomData,
197        }
198    }
199}
200
201impl<D: Domain> Clone for ThreadSafeConnection<D> {
202    fn clone(&self) -> Self {
203        Self {
204            uri: self.uri.clone(),
205            persistent: self.persistent,
206            connection_initialize_query: self.connection_initialize_query.clone(),
207            connections: self.connections.clone(),
208            _migrator: PhantomData,
209        }
210    }
211}
212
213// TODO:
214//  1. When migration or initialization fails, move the corrupted db to a holding place and create a new one
215//  2. If the new db also fails, downgrade to a shared in memory db
216//  3. In either case notify the user about what went wrong
217impl<M: Migrator> Deref for ThreadSafeConnection<M> {
218    type Target = Connection;
219
220    fn deref(&self) -> &Self::Target {
221        self.connections.get_or(|| self.create_connection())
222    }
223}
224
#[cfg(test)]
mod test {
    use indoc::indoc;
    // Use the standard `Deref` trait rather than lazy_static's hidden `__Deref` re-export.
    use std::{ops::Deref, thread};

    use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};

    /// Builds the same migrated connection from many threads at once. Every build funnels its
    /// initialization and migration through the single writer thread for the uri, so none of
    /// the worker threads should panic.
    #[test]
    fn many_initialize_and_migrate_queries_at_once() {
        let mut handles = vec![];

        enum TestDomain {}
        impl Domain for TestDomain {
            fn name() -> &'static str {
                "test"
            }
            fn migrations() -> &'static [&'static str] {
                &["CREATE TABLE test(col1 TEXT, col2 TEXT) STRICT;"]
            }
        }

        for _ in 0..100 {
            handles.push(thread::spawn(|| {
                let builder =
                    ThreadSafeConnection::<TestDomain>::builder("annoying-test.db", false)
                        .with_db_initialization_query("PRAGMA journal_mode=WAL")
                        .with_connection_initialize_query(indoc! {"
                                PRAGMA synchronous=NORMAL;
                                PRAGMA busy_timeout=1;
                                PRAGMA foreign_keys=TRUE;
                                PRAGMA case_sensitive_like=TRUE;
                            "});
                // Dereference to also open a read connection on this thread.
                let _ = smol::block_on(builder.build()).deref();
            }));
        }

        for handle in handles {
            // Re-raise any worker panic so the test actually fails on it, instead of
            // silently discarding the join error.
            if let Err(panic) = handle.join() {
                std::panic::resume_unwind(panic);
            }
        }
    }

    /// The `workspaces` table declares `FOREIGN KEY(active_pane)` without defining an
    /// `active_pane` column, so the migration is expected to fail. That failure panics on the
    /// writer thread and reaches this thread through the awaited `build` future.
    #[test]
    #[should_panic]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}
        impl Domain for TestWorkspace {
            fn name() -> &'static str {
                "workspace"
            }

            fn migrations() -> &'static [&'static str] {
                &["
                    CREATE TABLE workspaces(
                        workspace_id INTEGER PRIMARY KEY,
                        dock_visible INTEGER, -- Boolean
                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                    ) STRICT;
                    
                    CREATE TABLE panes(
                        pane_id INTEGER PRIMARY KEY,
                        workspace_id INTEGER NOT NULL,
                        active INTEGER NOT NULL, -- Boolean
                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) 
                            ON DELETE CASCADE 
                            ON UPDATE CASCADE
                    ) STRICT;
                "]
            }
        }

        let builder =
            ThreadSafeConnection::<TestWorkspace>::builder("wild_zed_lost_failure", false)
                .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true");

        smol::block_on(builder.build());
    }
}