thread_safe_connection.rs

  1use anyhow::Context as _;
  2use collections::HashMap;
  3use futures::{Future, FutureExt, channel::oneshot};
  4use parking_lot::{Mutex, RwLock};
  5use std::{
  6    marker::PhantomData,
  7    ops::Deref,
  8    sync::{Arc, LazyLock},
  9    thread,
 10};
 11use thread_local::ThreadLocal;
 12
 13use crate::{connection::Connection, domain::Migrator, util::UnboundedSyncSender};
 14
/// Maximum number of times `build` attempts to run the migrations before giving up
/// and returning the last migration error.
const MIGRATION_RETRIES: usize = 10;

/// A single write operation, boxed so it can be handed to a write queue.
type QueuedWrite = Box<dyn 'static + Send + FnOnce()>;
/// Accepts queued writes for one database; stored in `QUEUES` and shared between threads.
type WriteQueue = Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>;
/// Factory that produces the `WriteQueue` for a database uri that has none registered yet.
type WriteQueueConstructor = Box<dyn 'static + Send + FnMut() -> WriteQueue>;

/// Map from database uri to the queue of write tasks for that database. This lets us
/// serialize writes to the database and (with the default background thread queue) have a
/// single worker thread per db file. This means many thread safe connections
/// (possibly with different migrations) could all be communicating with the same background
/// thread.
static QUEUES: LazyLock<RwLock<HashMap<Arc<str>, WriteQueue>>> = LazyLock::new(Default::default);
 26
/// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static,
/// whatever. It derefs to a synchronous, read-only connection that belongs to the calling thread.
/// A write capable connection may be accessed by passing a callback to the `write` function, which
/// queues the callback on the write queue registered for this connection's uri.
#[derive(Clone)]
pub struct ThreadSafeConnection {
    /// Database identifier; also the key used to look up the write queue in `QUEUES`.
    uri: Arc<str>,
    /// `true` opens the uri as a file, `false` opens a shared in-memory database named by the uri.
    persistent: bool,
    /// Optional query executed on every newly opened per-thread connection.
    connection_initialize_query: Option<&'static str>,
    /// Lazily created connections, one per thread; shared between clones through the `Arc`.
    connections: Arc<ThreadLocal<Connection>>,
}

// NOTE(review): `Send`/`Sync` are asserted manually here rather than derived. Presumably this
// relies on every thread only ever touching the `Connection` stored in its own thread-local
// slot — confirm that this also covers connections being dropped on a different thread than
// the one that created them.
unsafe impl Send for ThreadSafeConnection {}
unsafe impl Sync for ThreadSafeConnection {}
 40
/// Builder for a [`ThreadSafeConnection`] that runs the migrations of `M` when `build` is called.
pub struct ThreadSafeConnectionBuilder<M: Migrator + 'static = ()> {
    /// Optional query executed once through the write queue at the start of `build`.
    db_initialize_query: Option<&'static str>,
    /// Optional factory for the write queue; handed to `initialize_queues` by `build`.
    write_queue_constructor: Option<WriteQueueConstructor>,
    /// The connection being configured; returned by `build` on success.
    connection: ThreadSafeConnection,
    /// Zero-sized marker tying the builder to its migrator type. A raw-pointer `PhantomData`
    /// by itself opts the builder out of the `Send`/`Sync` auto traits.
    _migrator: PhantomData<*mut M>,
}
 47
 48impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
 49    /// Sets the query to run every time a connection is opened. This must
 50    /// be infallible (EG only use pragma statements) and not cause writes.
 51    /// to the db or it will panic.
 52    pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
 53        self.connection.connection_initialize_query = Some(initialize_query);
 54        self
 55    }
 56
 57    /// Queues an initialization query for the database file. This must be infallible
 58    /// but may cause changes to the database file such as with `PRAGMA journal_mode`
 59    pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
 60        self.db_initialize_query = Some(initialize_query);
 61        self
 62    }
 63
 64    /// Specifies how the thread safe connection should serialize writes. If provided
 65    /// the connection will call the write_queue_constructor for each database file in
 66    /// this process. The constructor is responsible for setting up a background thread or
 67    /// async task which handles queued writes with the provided connection.
 68    pub fn with_write_queue_constructor(
 69        mut self,
 70        write_queue_constructor: WriteQueueConstructor,
 71    ) -> Self {
 72        self.write_queue_constructor = Some(write_queue_constructor);
 73        self
 74    }
 75
 76    pub async fn build(self) -> anyhow::Result<ThreadSafeConnection> {
 77        self.connection
 78            .initialize_queues(self.write_queue_constructor);
 79
 80        let db_initialize_query = self.db_initialize_query;
 81
 82        self.connection
 83            .write(move |connection| {
 84                if let Some(db_initialize_query) = db_initialize_query {
 85                    connection.exec(db_initialize_query).with_context(|| {
 86                        format!(
 87                            "Db initialize query failed to execute: {}",
 88                            db_initialize_query
 89                        )
 90                    })?()?;
 91                }
 92
 93                // Retry failed migrations in case they were run in parallel from different
 94                // processes. This gives a best attempt at migrating before bailing
 95                let mut migration_result =
 96                    anyhow::Result::<()>::Err(anyhow::anyhow!("Migration never run"));
 97
 98                let foreign_keys_enabled: bool =
 99                    connection.select_row::<i32>("PRAGMA foreign_keys")?()
100                        .unwrap_or(None)
101                        .map(|enabled| enabled != 0)
102                        .unwrap_or(false);
103
104                connection.exec("PRAGMA foreign_keys = OFF;")?()?;
105
106                for _ in 0..MIGRATION_RETRIES {
107                    migration_result = connection
108                        .with_savepoint("thread_safe_multi_migration", || M::migrate(connection));
109
110                    if migration_result.is_ok() {
111                        break;
112                    }
113                }
114
115                if foreign_keys_enabled {
116                    connection.exec("PRAGMA foreign_keys = ON;")?()?;
117                }
118                migration_result
119            })
120            .await?;
121
122        Ok(self.connection)
123    }
124}
125
126impl ThreadSafeConnection {
127    fn initialize_queues(&self, write_queue_constructor: Option<WriteQueueConstructor>) -> bool {
128        if !QUEUES.read().contains_key(&self.uri) {
129            let mut queues = QUEUES.write();
130            if !queues.contains_key(&self.uri) {
131                let mut write_queue_constructor =
132                    write_queue_constructor.unwrap_or_else(background_thread_queue);
133                queues.insert(self.uri.clone(), write_queue_constructor());
134                return true;
135            }
136        }
137        false
138    }
139
140    pub fn builder<M: Migrator>(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
141        ThreadSafeConnectionBuilder::<M> {
142            db_initialize_query: None,
143            write_queue_constructor: None,
144            connection: Self {
145                uri: Arc::from(uri),
146                persistent,
147                connection_initialize_query: None,
148                connections: Default::default(),
149            },
150            _migrator: PhantomData,
151        }
152    }
153
154    /// Opens a new db connection with the initialized file path. This is internal and only
155    /// called from the deref function.
156    fn open_file(uri: &str) -> Connection {
157        Connection::open_file(uri)
158    }
159
160    /// Opens a shared memory connection using the file path as the identifier. This is internal
161    /// and only called from the deref function.
162    fn open_shared_memory(uri: &str) -> Connection {
163        Connection::open_memory(Some(uri))
164    }
165
166    pub fn write<T: 'static + Send + Sync>(
167        &self,
168        callback: impl 'static + Send + FnOnce(&Connection) -> T,
169    ) -> impl Future<Output = T> {
170        // Check and invalidate queue and maybe recreate queue
171        let queues = QUEUES.read();
172        let write_channel = queues
173            .get(&self.uri)
174            .expect("Queues are inserted when build is called. This should always succeed");
175
176        // Create a one shot channel for the result of the queued write
177        // so we can await on the result
178        let (sender, receiver) = oneshot::channel();
179
180        let thread_safe_connection = (*self).clone();
181        write_channel(Box::new(move || {
182            let connection = thread_safe_connection.deref();
183            let result = connection.with_write(|connection| callback(connection));
184            sender.send(result).ok();
185        }));
186        receiver.map(|response| response.expect("Write queue unexpectedly closed"))
187    }
188
189    pub(crate) fn create_connection(
190        persistent: bool,
191        uri: &str,
192        connection_initialize_query: Option<&'static str>,
193    ) -> Connection {
194        let mut connection = if persistent {
195            Self::open_file(uri)
196        } else {
197            Self::open_shared_memory(uri)
198        };
199
200        // Disallow writes on the connection. The only writes allowed for thread safe connections
201        // are from the background thread that can serialize them.
202        *connection.write.get_mut() = false;
203
204        if let Some(initialize_query) = connection_initialize_query {
205            connection.exec(initialize_query).unwrap_or_else(|_| {
206                panic!("Initialize query failed to execute: {}", initialize_query)
207            })()
208            .unwrap()
209        }
210
211        connection
212    }
213}
214
215impl ThreadSafeConnection {
216    /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
217    /// This allows construction to be infallible and not write to the db.
218    pub fn new(
219        uri: &str,
220        persistent: bool,
221        connection_initialize_query: Option<&'static str>,
222        write_queue_constructor: Option<WriteQueueConstructor>,
223    ) -> Self {
224        let connection = Self {
225            uri: Arc::from(uri),
226            persistent,
227            connection_initialize_query,
228            connections: Default::default(),
229        };
230
231        connection.initialize_queues(write_queue_constructor);
232        connection
233    }
234}
235
236impl Deref for ThreadSafeConnection {
237    type Target = Connection;
238
239    fn deref(&self) -> &Self::Target {
240        self.connections.get_or(|| {
241            Self::create_connection(self.persistent, &self.uri, self.connection_initialize_query)
242        })
243    }
244}
245
246pub fn background_thread_queue() -> WriteQueueConstructor {
247    use std::sync::mpsc::channel;
248
249    Box::new(|| {
250        let (sender, receiver) = channel::<QueuedWrite>();
251
252        thread::Builder::new()
253            .name("sqlezWorker".to_string())
254            .spawn(move || {
255                while let Ok(write) = receiver.recv() {
256                    write()
257                }
258            })
259            .unwrap();
260
261        let sender = UnboundedSyncSender::new(sender);
262        Box::new(move |queued_write| {
263            sender
264                .send(queued_write)
265                .expect("Could not send write action to background thread");
266        })
267    })
268}
269
270pub fn locking_queue() -> WriteQueueConstructor {
271    Box::new(|| {
272        let write_mutex = Mutex::new(());
273        Box::new(move |queued_write| {
274            let _lock = write_mutex.lock();
275            queued_write();
276        })
277    })
278}
279
#[cfg(test)]
mod test {
    use indoc::indoc;
    use std::ops::Deref;

    use std::thread;

    use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};

    /// Builds and migrates the same shared in-memory database from 100 threads at once.
    /// The join results are discarded, so a panic inside a worker thread does not fail
    /// this test by itself.
    #[test]
    fn many_initialize_and_migrate_queries_at_once() {
        enum TestDomain {}
        impl Domain for TestDomain {
            const NAME: &str = "test";
            const MIGRATIONS: &[&str] = &["CREATE TABLE test(col1 TEXT, col2 TEXT) STRICT;"];
        }

        // Spawn every worker before joining any of them so the builds overlap.
        let workers: Vec<_> = (0..100)
            .map(|_| {
                thread::spawn(|| {
                    let builder =
                        ThreadSafeConnection::builder::<TestDomain>("annoying-test.db", false)
                            .with_db_initialization_query("PRAGMA journal_mode=WAL")
                            .with_connection_initialize_query(indoc! {"
                                PRAGMA synchronous=NORMAL;
                                PRAGMA busy_timeout=1;
                                PRAGMA foreign_keys=TRUE;
                                PRAGMA case_sensitive_like=TRUE;
                            "});

                    let connection = smol::block_on(builder.build()).unwrap();
                    // Force this thread's read connection to be opened as well.
                    let _ = connection.deref();
                })
            })
            .collect();

        for worker in workers {
            let _ = worker.join();
        }
    }

    /// Expects `build` to fail (and the `unwrap` to panic) for this schema.
    /// NOTE(review): presumably the migration is rejected because `workspaces` declares a
    /// foreign key on `active_pane`, a column the table never defines — confirm.
    #[test]
    #[should_panic]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}
        impl Domain for TestWorkspace {
            const NAME: &str = "workspace";

            const MIGRATIONS: &[&str] = &["
                    CREATE TABLE workspaces(
                        workspace_id INTEGER PRIMARY KEY,
                        dock_visible INTEGER, -- Boolean
                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                    ) STRICT;

                    CREATE TABLE panes(
                        pane_id INTEGER PRIMARY KEY,
                        workspace_id INTEGER NOT NULL,
                        active INTEGER NOT NULL, -- Boolean
                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
                            ON DELETE CASCADE
                            ON UPDATE CASCADE
                    ) STRICT;
                "];
        }

        let build_result = smol::block_on(
            ThreadSafeConnection::builder::<TestWorkspace>("wild_zed_lost_failure", false)
                .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true")
                .build(),
        );

        build_result.unwrap();
    }
}