thread_safe_connection.rs

  1use futures::{channel::oneshot, Future, FutureExt};
  2use lazy_static::lazy_static;
  3use parking_lot::RwLock;
  4use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
  5use thread_local::ThreadLocal;
  6
  7use crate::{
  8    connection::Connection,
  9    domain::{Domain, Migrator},
 10    util::UnboundedSyncSender,
 11};
 12
 13const MIGRATION_RETRIES: usize = 10;
 14
 15type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
 16type WriteQueueConstructor =
 17    Box<dyn 'static + Send + FnMut(Connection) -> Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>;
 18lazy_static! {
 19    /// List of queues of tasks by database uri. This lets us serialize writes to the database
 20    /// and have a single worker thread per db file. This means many thread safe connections
 21    /// (possibly with different migrations) could all be communicating with the same background
 22    /// thread.
 23    static ref QUEUES: RwLock<HashMap<Arc<str>, Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>> =
 24        Default::default();
 25}
 26
 27/// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static,
 28/// whatever. It derefs to a synchronous connection by thread that is read only. A write capable connection
 29/// may be accessed by passing a callback to the `write` function which will queue the callback
 30pub struct ThreadSafeConnection<M: Migrator = ()> {
 31    uri: Arc<str>,
 32    persistent: bool,
 33    connection_initialize_query: Option<&'static str>,
 34    connections: Arc<ThreadLocal<Connection>>,
 35    _migrator: PhantomData<M>,
 36}
 37
 38unsafe impl<T: Migrator> Send for ThreadSafeConnection<T> {}
 39unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
 40
 41pub struct ThreadSafeConnectionBuilder<M: Migrator = ()> {
 42    db_initialize_query: Option<&'static str>,
 43    write_queue_constructor: Option<WriteQueueConstructor>,
 44    connection: ThreadSafeConnection<M>,
 45}
 46
 47impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
 48    /// Sets the query to run every time a connection is opened. This must
 49    /// be infallible (EG only use pragma statements) and not cause writes.
 50    /// to the db or it will panic.
 51    pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
 52        self.connection.connection_initialize_query = Some(initialize_query);
 53        self
 54    }
 55
 56    /// Specifies how the thread safe connection should serialize writes. If provided
 57    /// the connection will call the write_queue_constructor for each database file in
 58    /// this process. The constructor is responsible for setting up a background thread or
 59    /// async task which handles queued writes with the provided connection.
 60    pub fn with_write_queue_constructor(
 61        mut self,
 62        write_queue_constructor: WriteQueueConstructor,
 63    ) -> Self {
 64        self.write_queue_constructor = Some(write_queue_constructor);
 65        self
 66    }
 67
 68    /// Queues an initialization query for the database file. This must be infallible
 69    /// but may cause changes to the database file such as with `PRAGMA journal_mode`
 70    pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
 71        self.db_initialize_query = Some(initialize_query);
 72        self
 73    }
 74
 75    pub async fn build(self) -> ThreadSafeConnection<M> {
 76        if !QUEUES.read().contains_key(&self.connection.uri) {
 77            let mut queues = QUEUES.write();
 78            if !queues.contains_key(&self.connection.uri) {
 79                let mut write_connection = self.connection.create_connection();
 80                // Enable writes for this connection
 81                write_connection.write = true;
 82                if let Some(mut write_queue_constructor) = self.write_queue_constructor {
 83                    let write_channel = write_queue_constructor(write_connection);
 84                    queues.insert(self.connection.uri.clone(), write_channel);
 85                } else {
 86                    use std::sync::mpsc::channel;
 87
 88                    let (sender, reciever) = channel::<QueuedWrite>();
 89                    thread::spawn(move || {
 90                        while let Ok(write) = reciever.recv() {
 91                            write(&write_connection)
 92                        }
 93                    });
 94
 95                    let sender = UnboundedSyncSender::new(sender);
 96                    queues.insert(
 97                        self.connection.uri.clone(),
 98                        Box::new(move |queued_write| {
 99                            sender
100                                .send(queued_write)
101                                .expect("Could not send write action to backgorund thread");
102                        }),
103                    );
104                }
105            }
106        }
107
108        let db_initialize_query = self.db_initialize_query;
109
110        self.connection
111            .write(move |connection| {
112                if let Some(db_initialize_query) = db_initialize_query {
113                    connection.exec(db_initialize_query).expect(&format!(
114                        "Db initialize query failed to execute: {}",
115                        db_initialize_query
116                    ))()
117                    .unwrap();
118                }
119
120                let mut failure_result = None;
121                for _ in 0..MIGRATION_RETRIES {
122                    failure_result = Some(M::migrate(connection));
123                    if failure_result.as_ref().unwrap().is_ok() {
124                        break;
125                    }
126                }
127
128                failure_result.unwrap().expect("Migration failed");
129            })
130            .await;
131
132        self.connection
133    }
134}
135
136impl<M: Migrator> ThreadSafeConnection<M> {
137    pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
138        ThreadSafeConnectionBuilder::<M> {
139            db_initialize_query: None,
140            write_queue_constructor: None,
141            connection: Self {
142                uri: Arc::from(uri),
143                persistent,
144                connection_initialize_query: None,
145                connections: Default::default(),
146                _migrator: PhantomData,
147            },
148        }
149    }
150
151    /// Opens a new db connection with the initialized file path. This is internal and only
152    /// called from the deref function.
153    fn open_file(&self) -> Connection {
154        Connection::open_file(self.uri.as_ref())
155    }
156
157    /// Opens a shared memory connection using the file path as the identifier. This is internal
158    /// and only called from the deref function.
159    fn open_shared_memory(&self) -> Connection {
160        Connection::open_memory(Some(self.uri.as_ref()))
161    }
162
163    pub fn write<T: 'static + Send + Sync>(
164        &self,
165        callback: impl 'static + Send + FnOnce(&Connection) -> T,
166    ) -> impl Future<Output = T> {
167        let queues = QUEUES.read();
168        let write_channel = queues
169            .get(&self.uri)
170            .expect("Queues are inserted when build is called. This should always succeed");
171
172        // Create a one shot channel for the result of the queued write
173        // so we can await on the result
174        let (sender, reciever) = oneshot::channel();
175        write_channel(Box::new(move |connection| {
176            sender.send(callback(connection)).ok();
177        }));
178        reciever.map(|response| response.expect("Background writer thread unexpectedly closed"))
179    }
180
181    pub(crate) fn create_connection(&self) -> Connection {
182        let mut connection = if self.persistent {
183            self.open_file()
184        } else {
185            self.open_shared_memory()
186        };
187
188        // Disallow writes on the connection. The only writes allowed for thread safe connections
189        // are from the background thread that can serialize them.
190        connection.write = false;
191
192        if let Some(initialize_query) = self.connection_initialize_query {
193            connection.exec(initialize_query).expect(&format!(
194                "Initialize query failed to execute: {}",
195                initialize_query
196            ))()
197            .unwrap()
198        }
199
200        connection
201    }
202}
203
204impl ThreadSafeConnection<()> {
205    /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
206    /// This allows construction to be infallible and not write to the db.
207    pub fn new(
208        uri: &str,
209        persistent: bool,
210        connection_initialize_query: Option<&'static str>,
211    ) -> Self {
212        Self {
213            uri: Arc::from(uri),
214            persistent,
215            connection_initialize_query,
216            connections: Default::default(),
217            _migrator: PhantomData,
218        }
219    }
220}
221
222impl<D: Domain> Clone for ThreadSafeConnection<D> {
223    fn clone(&self) -> Self {
224        Self {
225            uri: self.uri.clone(),
226            persistent: self.persistent,
227            connection_initialize_query: self.connection_initialize_query.clone(),
228            connections: self.connections.clone(),
229            _migrator: PhantomData,
230        }
231    }
232}
233
234// TODO:
235//  1. When migration or initialization fails, move the corrupted db to a holding place and create a new one
236//  2. If the new db also fails, downgrade to a shared in memory db
237//  3. In either case notify the user about what went wrong
238impl<M: Migrator> Deref for ThreadSafeConnection<M> {
239    type Target = Connection;
240
241    fn deref(&self) -> &Self::Target {
242        self.connections.get_or(|| self.create_connection())
243    }
244}
245
#[cfg(test)]
mod test {
    use indoc::indoc;
    use std::{ops::Deref, thread};

    use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};

    /// Builds the same shared-memory database from many threads at once. This exercises
    /// concurrent write queue registration and serialized initialization/migrations. Any
    /// panic in a worker thread fails the test.
    #[test]
    fn many_initialize_and_migrate_queries_at_once() {
        let mut handles = vec![];

        enum TestDomain {}
        impl Domain for TestDomain {
            fn name() -> &'static str {
                "test"
            }
            fn migrations() -> &'static [&'static str] {
                &["CREATE TABLE test(col1 TEXT, col2 TEXT) STRICT;"]
            }
        }

        for _ in 0..100 {
            handles.push(thread::spawn(|| {
                let builder =
                    ThreadSafeConnection::<TestDomain>::builder("annoying-test.db", false)
                        .with_db_initialization_query("PRAGMA journal_mode=WAL")
                        .with_connection_initialize_query(indoc! {"
                                PRAGMA synchronous=NORMAL;
                                PRAGMA busy_timeout=1;
                                PRAGMA foreign_keys=TRUE;
                                PRAGMA case_sensitive_like=TRUE;
                            "});
                let connection = smol::block_on(builder.build());
                // Force this thread's read connection (and its initialize query) to be created.
                let _ = connection.deref();
            }));
        }

        for handle in handles {
            handle
                .join()
                .expect("Thread building the connection panicked");
        }
    }

    /// The `workspaces` migration declares `FOREIGN KEY(active_pane)` without an `active_pane`
    /// column, so SQLite rejects the migration and `build` must panic.
    #[test]
    #[should_panic]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}
        impl Domain for TestWorkspace {
            fn name() -> &'static str {
                "workspace"
            }

            fn migrations() -> &'static [&'static str] {
                &["
                    CREATE TABLE workspaces(
                        workspace_id INTEGER PRIMARY KEY,
                        dock_visible INTEGER, -- Boolean
                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                    ) STRICT;
                    
                    CREATE TABLE panes(
                        pane_id INTEGER PRIMARY KEY,
                        workspace_id INTEGER NOT NULL,
                        active INTEGER NOT NULL, -- Boolean
                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) 
                            ON DELETE CASCADE 
                            ON UPDATE CASCADE
                    ) STRICT;
                "]
            }
        }

        let builder =
            ThreadSafeConnection::<TestWorkspace>::builder("wild_zed_lost_failure", false)
                .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true");

        smol::block_on(builder.build());
    }
}