thread_safe_connection.rs

  1use futures::{channel::oneshot, Future, FutureExt};
  2use lazy_static::lazy_static;
  3use parking_lot::{Mutex, RwLock};
  4use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
  5use thread_local::ThreadLocal;
  6
  7use crate::{
  8    connection::Connection,
  9    domain::{Domain, Migrator},
 10    util::UnboundedSyncSender,
 11};
 12
 13const MIGRATION_RETRIES: usize = 10;
 14
 15type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
 16type WriteQueueConstructor =
 17    Box<dyn 'static + Send + FnMut(Connection) -> Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>;
 18lazy_static! {
 19    /// List of queues of tasks by database uri. This lets us serialize writes to the database
 20    /// and have a single worker thread per db file. This means many thread safe connections
 21    /// (possibly with different migrations) could all be communicating with the same background
 22    /// thread.
 23    static ref QUEUES: RwLock<HashMap<Arc<str>, Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>> =
 24        Default::default();
 25}
 26
 27/// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static,
 28/// whatever. It derefs to a synchronous connection by thread that is read only. A write capable connection
 29/// may be accessed by passing a callback to the `write` function which will queue the callback
 30pub struct ThreadSafeConnection<M: Migrator = ()> {
 31    uri: Arc<str>,
 32    persistent: bool,
 33    connection_initialize_query: Option<&'static str>,
 34    connections: Arc<ThreadLocal<Connection>>,
 35    _migrator: PhantomData<M>,
 36}
 37
 38unsafe impl<T: Migrator> Send for ThreadSafeConnection<T> {}
 39unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
 40
 41pub struct ThreadSafeConnectionBuilder<M: Migrator = ()> {
 42    db_initialize_query: Option<&'static str>,
 43    write_queue_constructor: Option<WriteQueueConstructor>,
 44    connection: ThreadSafeConnection<M>,
 45}
 46
 47impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
 48    /// Sets the query to run every time a connection is opened. This must
 49    /// be infallible (EG only use pragma statements) and not cause writes.
 50    /// to the db or it will panic.
 51    pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
 52        self.connection.connection_initialize_query = Some(initialize_query);
 53        self
 54    }
 55
 56    /// Specifies how the thread safe connection should serialize writes. If provided
 57    /// the connection will call the write_queue_constructor for each database file in
 58    /// this process. The constructor is responsible for setting up a background thread or
 59    /// async task which handles queued writes with the provided connection.
 60    pub fn with_write_queue_constructor(
 61        mut self,
 62        write_queue_constructor: WriteQueueConstructor,
 63    ) -> Self {
 64        self.write_queue_constructor = Some(write_queue_constructor);
 65        self
 66    }
 67
 68    /// Queues an initialization query for the database file. This must be infallible
 69    /// but may cause changes to the database file such as with `PRAGMA journal_mode`
 70    pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
 71        self.db_initialize_query = Some(initialize_query);
 72        self
 73    }
 74
 75    pub async fn build(self) -> ThreadSafeConnection<M> {
 76        self.connection
 77            .initialize_queues(self.write_queue_constructor);
 78
 79        let db_initialize_query = self.db_initialize_query;
 80
 81        self.connection
 82            .write(move |connection| {
 83                if let Some(db_initialize_query) = db_initialize_query {
 84                    connection.exec(db_initialize_query).expect(&format!(
 85                        "Db initialize query failed to execute: {}",
 86                        db_initialize_query
 87                    ))()
 88                    .unwrap();
 89                }
 90
 91                let mut failure_result = None;
 92                for _ in 0..MIGRATION_RETRIES {
 93                    failure_result = Some(M::migrate(connection));
 94                    if failure_result.as_ref().unwrap().is_ok() {
 95                        break;
 96                    }
 97                }
 98
 99                failure_result.unwrap().expect("Migration failed");
100            })
101            .await;
102
103        self.connection
104    }
105}
106
107impl<M: Migrator> ThreadSafeConnection<M> {
108    fn initialize_queues(&self, write_queue_constructor: Option<WriteQueueConstructor>) {
109        if !QUEUES.read().contains_key(&self.uri) {
110            let mut queues = QUEUES.write();
111            if !queues.contains_key(&self.uri) {
112                let mut write_connection = self.create_connection();
113                // Enable writes for this connection
114                write_connection.write = true;
115                if let Some(mut write_queue_constructor) = write_queue_constructor {
116                    let write_channel = write_queue_constructor(write_connection);
117                    queues.insert(self.uri.clone(), write_channel);
118                } else {
119                    use std::sync::mpsc::channel;
120
121                    let (sender, reciever) = channel::<QueuedWrite>();
122                    thread::spawn(move || {
123                        while let Ok(write) = reciever.recv() {
124                            write(&write_connection)
125                        }
126                    });
127
128                    let sender = UnboundedSyncSender::new(sender);
129                    queues.insert(
130                        self.uri.clone(),
131                        Box::new(move |queued_write| {
132                            sender
133                                .send(queued_write)
134                                .expect("Could not send write action to backgorund thread");
135                        }),
136                    );
137                }
138            }
139        }
140    }
141
142    pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
143        ThreadSafeConnectionBuilder::<M> {
144            db_initialize_query: None,
145            write_queue_constructor: None,
146            connection: Self {
147                uri: Arc::from(uri),
148                persistent,
149                connection_initialize_query: None,
150                connections: Default::default(),
151                _migrator: PhantomData,
152            },
153        }
154    }
155
156    /// Opens a new db connection with the initialized file path. This is internal and only
157    /// called from the deref function.
158    fn open_file(&self) -> Connection {
159        Connection::open_file(self.uri.as_ref())
160    }
161
162    /// Opens a shared memory connection using the file path as the identifier. This is internal
163    /// and only called from the deref function.
164    fn open_shared_memory(&self) -> Connection {
165        Connection::open_memory(Some(self.uri.as_ref()))
166    }
167
168    pub fn write<T: 'static + Send + Sync>(
169        &self,
170        callback: impl 'static + Send + FnOnce(&Connection) -> T,
171    ) -> impl Future<Output = T> {
172        let queues = QUEUES.read();
173        let write_channel = queues
174            .get(&self.uri)
175            .expect("Queues are inserted when build is called. This should always succeed");
176
177        // Create a one shot channel for the result of the queued write
178        // so we can await on the result
179        let (sender, reciever) = oneshot::channel();
180        write_channel(Box::new(move |connection| {
181            sender.send(callback(connection)).ok();
182        }));
183        reciever.map(|response| response.expect("Background writer thread unexpectedly closed"))
184    }
185
186    pub(crate) fn create_connection(&self) -> Connection {
187        let mut connection = if self.persistent {
188            self.open_file()
189        } else {
190            self.open_shared_memory()
191        };
192
193        // Disallow writes on the connection. The only writes allowed for thread safe connections
194        // are from the background thread that can serialize them.
195        connection.write = false;
196
197        if let Some(initialize_query) = self.connection_initialize_query {
198            connection.exec(initialize_query).expect(&format!(
199                "Initialize query failed to execute: {}",
200                initialize_query
201            ))()
202            .unwrap()
203        }
204
205        connection
206    }
207}
208
209impl ThreadSafeConnection<()> {
210    /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
211    /// This allows construction to be infallible and not write to the db.
212    pub fn new(
213        uri: &str,
214        persistent: bool,
215        connection_initialize_query: Option<&'static str>,
216        write_queue_constructor: Option<WriteQueueConstructor>,
217    ) -> Self {
218        let connection = Self {
219            uri: Arc::from(uri),
220            persistent,
221            connection_initialize_query,
222            connections: Default::default(),
223            _migrator: PhantomData,
224        };
225
226        connection.initialize_queues(write_queue_constructor);
227        connection
228    }
229}
230
231impl<D: Domain> Clone for ThreadSafeConnection<D> {
232    fn clone(&self) -> Self {
233        Self {
234            uri: self.uri.clone(),
235            persistent: self.persistent,
236            connection_initialize_query: self.connection_initialize_query.clone(),
237            connections: self.connections.clone(),
238            _migrator: PhantomData,
239        }
240    }
241}
242
243// TODO:
244//  1. When migration or initialization fails, move the corrupted db to a holding place and create a new one
245//  2. If the new db also fails, downgrade to a shared in memory db
246//  3. In either case notify the user about what went wrong
247impl<M: Migrator> Deref for ThreadSafeConnection<M> {
248    type Target = Connection;
249
250    fn deref(&self) -> &Self::Target {
251        self.connections.get_or(|| self.create_connection())
252    }
253}
254
255pub fn locking_queue() -> WriteQueueConstructor {
256    Box::new(|connection| {
257        let connection = Mutex::new(connection);
258        Box::new(move |queued_write| {
259            let connection = connection.lock();
260            queued_write(&connection)
261        })
262    })
263}
264
#[cfg(test)]
mod test {
    use indoc::indoc;
    // `Deref` comes from std; `lazy_static::__Deref` is a hidden re-export meant only for
    // the macro's own expansion.
    use std::{ops::Deref, thread};

    use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};

    /// Builds the same in-memory database from 100 threads at once, each running the db
    /// initialization query, the connection initialize query and the migration.
    #[test]
    fn many_initialize_and_migrate_queries_at_once() {
        let mut handles = vec![];

        enum TestDomain {}
        impl Domain for TestDomain {
            fn name() -> &'static str {
                "test"
            }
            fn migrations() -> &'static [&'static str] {
                &["CREATE TABLE test(col1 TEXT, col2 TEXT) STRICT;"]
            }
        }

        for _ in 0..100 {
            handles.push(thread::spawn(|| {
                let builder =
                    ThreadSafeConnection::<TestDomain>::builder("annoying-test.db", false)
                        .with_db_initialization_query("PRAGMA journal_mode=WAL")
                        .with_connection_initialize_query(indoc! {"
                                PRAGMA synchronous=NORMAL;
                                PRAGMA busy_timeout=1;
                                PRAGMA foreign_keys=TRUE;
                                PRAGMA case_sensitive_like=TRUE;
                            "});
                // Dereferencing forces this thread's own connection to be opened.
                let _ = smol::block_on(builder.build()).deref();
            }));
        }

        // NOTE(review): the join result is discarded, so a panic inside a spawned thread
        // does not fail this test.
        for handle in handles {
            let _ = handle.join();
        }
    }

    /// Building must panic: the `workspaces` migration declares a foreign key on
    /// `active_pane`, a column the table does not define.
    #[test]
    #[should_panic]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}
        impl Domain for TestWorkspace {
            fn name() -> &'static str {
                "workspace"
            }

            fn migrations() -> &'static [&'static str] {
                &["
                    CREATE TABLE workspaces(
                        workspace_id INTEGER PRIMARY KEY,
                        dock_visible INTEGER, -- Boolean
                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                    ) STRICT;
                    
                    CREATE TABLE panes(
                        pane_id INTEGER PRIMARY KEY,
                        workspace_id INTEGER NOT NULL,
                        active INTEGER NOT NULL, -- Boolean
                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) 
                            ON DELETE CASCADE 
                            ON UPDATE CASCADE
                    ) STRICT;
                "]
            }
        }

        let builder =
            ThreadSafeConnection::<TestWorkspace>::builder("wild_zed_lost_failure", false)
                .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true");

        smol::block_on(builder.build());
    }
}