thread_safe_connection.rs

  1use anyhow::Context;
  2use futures::{channel::oneshot, Future, FutureExt};
  3use lazy_static::lazy_static;
  4use parking_lot::{Mutex, RwLock};
  5use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
  6use thread_local::ThreadLocal;
  7
  8use crate::{connection::Connection, domain::Migrator, util::UnboundedSyncSender};
  9
 10const MIGRATION_RETRIES: usize = 10;
 11
 12type QueuedWrite = Box<dyn 'static + Send + FnOnce()>;
 13type WriteQueueConstructor =
 14    Box<dyn 'static + Send + FnMut() -> Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>;
 15lazy_static! {
 16    /// List of queues of tasks by database uri. This lets us serialize writes to the database
 17    /// and have a single worker thread per db file. This means many thread safe connections
 18    /// (possibly with different migrations) could all be communicating with the same background
 19    /// thread.
 20    static ref QUEUES: RwLock<HashMap<Arc<str>, Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>> =
 21        Default::default();
 22}
 23
/// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static,
/// whatever. It derefs to a synchronous connection by thread that is read only. A write capable connection
/// may be accessed by passing a callback to the `write` function which will queue the callback
/// on the write queue registered for this connection's uri.
pub struct ThreadSafeConnection<M: Migrator + 'static = ()> {
    /// Database identifier; also the key used to look up the write queue in `QUEUES`.
    uri: Arc<str>,
    /// `true` opens `uri` as a file, `false` opens it as a shared in-memory database.
    persistent: bool,
    /// Query executed on every newly opened connection, if any.
    connection_initialize_query: Option<&'static str>,
    /// Per-thread connections, created on first deref and shared between clones.
    connections: Arc<ThreadLocal<Connection>>,
    /// Ties the type to its migrator without storing a value of `M`.
    _migrator: PhantomData<*mut M>,
}

// SAFETY: the `*mut M` inside `PhantomData` is only a type-level marker; no value of `M`
// and no raw pointer is ever stored in this struct, so the marker itself carries no
// thread-affine state.
// NOTE(review): these impls also vouch for every other field, in particular
// `Arc<ThreadLocal<Connection>>` — confirm `Connection` is safe to create on one thread
// and drop on another.
unsafe impl<M: Migrator> Send for ThreadSafeConnection<M> {}
unsafe impl<M: Migrator> Sync for ThreadSafeConnection<M> {}
 37
/// Builder for a [`ThreadSafeConnection`] that needs a database initialization query,
/// migrations, or a custom write queue. Obtained from `ThreadSafeConnection::builder`
/// and consumed by `build`.
pub struct ThreadSafeConnectionBuilder<M: Migrator + 'static = ()> {
    /// Query that `build` executes through the write queue before running migrations.
    db_initialize_query: Option<&'static str>,
    /// Optional factory for the write queue of this uri; when `None`,
    /// `background_thread_queue` is used.
    write_queue_constructor: Option<WriteQueueConstructor>,
    /// The connection being configured; returned by `build` on success.
    connection: ThreadSafeConnection<M>,
}
 43
 44impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
 45    /// Sets the query to run every time a connection is opened. This must
 46    /// be infallible (EG only use pragma statements) and not cause writes.
 47    /// to the db or it will panic.
 48    pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
 49        self.connection.connection_initialize_query = Some(initialize_query);
 50        self
 51    }
 52
 53    /// Queues an initialization query for the database file. This must be infallible
 54    /// but may cause changes to the database file such as with `PRAGMA journal_mode`
 55    pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
 56        self.db_initialize_query = Some(initialize_query);
 57        self
 58    }
 59
 60    /// Specifies how the thread safe connection should serialize writes. If provided
 61    /// the connection will call the write_queue_constructor for each database file in
 62    /// this process. The constructor is responsible for setting up a background thread or
 63    /// async task which handles queued writes with the provided connection.
 64    pub fn with_write_queue_constructor(
 65        mut self,
 66        write_queue_constructor: WriteQueueConstructor,
 67    ) -> Self {
 68        self.write_queue_constructor = Some(write_queue_constructor);
 69        self
 70    }
 71
 72    pub async fn build(self) -> anyhow::Result<ThreadSafeConnection<M>> {
 73        self.connection
 74            .initialize_queues(self.write_queue_constructor);
 75
 76        let db_initialize_query = self.db_initialize_query;
 77
 78        self.connection
 79            .write(move |connection| {
 80                if let Some(db_initialize_query) = db_initialize_query {
 81                    connection.exec(db_initialize_query).with_context(|| {
 82                        format!(
 83                            "Db initialize query failed to execute: {}",
 84                            db_initialize_query
 85                        )
 86                    })?()?;
 87                }
 88
 89                // Retry failed migrations in case they were run in parallel from different
 90                // processes. This gives a best attempt at migrating before bailing
 91                let mut migration_result =
 92                    anyhow::Result::<()>::Err(anyhow::anyhow!("Migration never run"));
 93
 94                for _ in 0..MIGRATION_RETRIES {
 95                    migration_result = connection
 96                        .with_savepoint("thread_safe_multi_migration", || M::migrate(connection));
 97
 98                    if migration_result.is_ok() {
 99                        println!("Migration succeded");
100                        break;
101                    }
102                }
103
104                migration_result
105            })
106            .await?;
107
108        Ok(self.connection)
109    }
110}
111
112impl<M: Migrator> ThreadSafeConnection<M> {
113    fn initialize_queues(&self, write_queue_constructor: Option<WriteQueueConstructor>) -> bool {
114        if !QUEUES.read().contains_key(&self.uri) {
115            let mut queues = QUEUES.write();
116            if !queues.contains_key(&self.uri) {
117                let mut write_queue_constructor =
118                    write_queue_constructor.unwrap_or(background_thread_queue());
119                queues.insert(self.uri.clone(), write_queue_constructor());
120                return true;
121            }
122        }
123        return false;
124    }
125
126    pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
127        ThreadSafeConnectionBuilder::<M> {
128            db_initialize_query: None,
129            write_queue_constructor: None,
130            connection: Self {
131                uri: Arc::from(uri),
132                persistent,
133                connection_initialize_query: None,
134                connections: Default::default(),
135                _migrator: PhantomData,
136            },
137        }
138    }
139
140    /// Opens a new db connection with the initialized file path. This is internal and only
141    /// called from the deref function.
142    fn open_file(uri: &str) -> Connection {
143        Connection::open_file(uri)
144    }
145
146    /// Opens a shared memory connection using the file path as the identifier. This is internal
147    /// and only called from the deref function.
148    fn open_shared_memory(uri: &str) -> Connection {
149        Connection::open_memory(Some(uri))
150    }
151
152    pub fn write<T: 'static + Send + Sync>(
153        &self,
154        callback: impl 'static + Send + FnOnce(&Connection) -> T,
155    ) -> impl Future<Output = T> {
156        // Check and invalidate queue and maybe recreate queue
157        let queues = QUEUES.read();
158        let write_channel = queues
159            .get(&self.uri)
160            .expect("Queues are inserted when build is called. This should always succeed");
161
162        // Create a one shot channel for the result of the queued write
163        // so we can await on the result
164        let (sender, reciever) = oneshot::channel();
165
166        let thread_safe_connection = (*self).clone();
167        write_channel(Box::new(move || {
168            let connection = thread_safe_connection.deref();
169            let result = connection.with_write(|connection| callback(connection));
170            sender.send(result).ok();
171        }));
172        reciever.map(|response| response.expect("Background writer thread unexpectedly closed"))
173    }
174
175    pub(crate) fn create_connection(
176        persistent: bool,
177        uri: &str,
178        connection_initialize_query: Option<&'static str>,
179    ) -> Connection {
180        let mut connection = if persistent {
181            Self::open_file(uri)
182        } else {
183            Self::open_shared_memory(uri)
184        };
185
186        // Disallow writes on the connection. The only writes allowed for thread safe connections
187        // are from the background thread that can serialize them.
188        *connection.write.get_mut() = false;
189
190        if let Some(initialize_query) = connection_initialize_query {
191            connection.exec(initialize_query).expect(&format!(
192                "Initialize query failed to execute: {}",
193                initialize_query
194            ))()
195            .unwrap()
196        }
197
198        connection
199    }
200}
201
202impl ThreadSafeConnection<()> {
203    /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
204    /// This allows construction to be infallible and not write to the db.
205    pub fn new(
206        uri: &str,
207        persistent: bool,
208        connection_initialize_query: Option<&'static str>,
209        write_queue_constructor: Option<WriteQueueConstructor>,
210    ) -> Self {
211        let connection = Self {
212            uri: Arc::from(uri),
213            persistent,
214            connection_initialize_query,
215            connections: Default::default(),
216            _migrator: PhantomData,
217        };
218
219        connection.initialize_queues(write_queue_constructor);
220        connection
221    }
222}
223
224impl<M: Migrator> Clone for ThreadSafeConnection<M> {
225    fn clone(&self) -> Self {
226        Self {
227            uri: self.uri.clone(),
228            persistent: self.persistent,
229            connection_initialize_query: self.connection_initialize_query.clone(),
230            connections: self.connections.clone(),
231            _migrator: PhantomData,
232        }
233    }
234}
235
236impl<M: Migrator> Deref for ThreadSafeConnection<M> {
237    type Target = Connection;
238
239    fn deref(&self) -> &Self::Target {
240        self.connections.get_or(|| {
241            Self::create_connection(self.persistent, &self.uri, self.connection_initialize_query)
242        })
243    }
244}
245
246pub fn background_thread_queue() -> WriteQueueConstructor {
247    use std::sync::mpsc::channel;
248
249    Box::new(|| {
250        let (sender, reciever) = channel::<QueuedWrite>();
251
252        thread::spawn(move || {
253            while let Ok(write) = reciever.recv() {
254                write()
255            }
256        });
257
258        let sender = UnboundedSyncSender::new(sender);
259        Box::new(move |queued_write| {
260            sender
261                .send(queued_write)
262                .expect("Could not send write action to background thread");
263        })
264    })
265}
266
267pub fn locking_queue() -> WriteQueueConstructor {
268    Box::new(|| {
269        let write_mutex = Mutex::new(());
270        Box::new(move |queued_write| {
271            let _lock = write_mutex.lock();
272            queued_write();
273        })
274    })
275}
276
#[cfg(test)]
mod test {
    use indoc::indoc;

    use std::{ops::Deref, thread};

    use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};

    /// Builds the same shared in-memory database from 100 threads at once, each one
    /// running the db initialization query, the connection initialize query and the
    /// migration.
    #[test]
    fn many_initialize_and_migrate_queries_at_once() {
        let mut handles = vec![];

        enum TestDomain {}
        impl Domain for TestDomain {
            fn name() -> &'static str {
                "test"
            }
            fn migrations() -> &'static [&'static str] {
                &["CREATE TABLE test(col1 TEXT, col2 TEXT) STRICT;"]
            }
        }

        for _ in 0..100 {
            handles.push(thread::spawn(|| {
                let builder =
                    ThreadSafeConnection::<TestDomain>::builder("annoying-test.db", false)
                        .with_db_initialization_query("PRAGMA journal_mode=WAL")
                        .with_connection_initialize_query(indoc! {"
                                PRAGMA synchronous=NORMAL;
                                PRAGMA busy_timeout=1;
                                PRAGMA foreign_keys=TRUE;
                                PRAGMA case_sensitive_like=TRUE;
                            "});

                // Deref to force this thread's connection to be opened as well.
                let _ = smol::block_on(builder.build()).unwrap().deref();
            }));
        }

        // NOTE(review): the join result is discarded, so a panic inside a worker thread
        // does not fail this test — confirm whether that is intended.
        for handle in handles {
            let _ = handle.join();
        }
    }

    /// Building with a migration that cannot be applied must panic on the `unwrap`.
    /// The `workspaces` table declares a foreign key on `active_pane`, a column it
    /// does not define — presumably that is what makes the migration fail.
    #[test]
    #[should_panic]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}
        impl Domain for TestWorkspace {
            fn name() -> &'static str {
                "workspace"
            }

            fn migrations() -> &'static [&'static str] {
                &["
                    CREATE TABLE workspaces(
                        workspace_id INTEGER PRIMARY KEY,
                        dock_visible INTEGER, -- Boolean
                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                    ) STRICT;
                    
                    CREATE TABLE panes(
                        pane_id INTEGER PRIMARY KEY,
                        workspace_id INTEGER NOT NULL,
                        active INTEGER NOT NULL, -- Boolean
                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) 
                            ON DELETE CASCADE 
                            ON UPDATE CASCADE
                    ) STRICT;
                "]
            }
        }

        let builder =
            ThreadSafeConnection::<TestWorkspace>::builder("wild_zed_lost_failure", false)
                .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true");

        smol::block_on(builder.build()).unwrap();
    }
}