thread_safe_connection.rs

  1use anyhow::Context as _;
  2use collections::HashMap;
  3use futures::{Future, FutureExt, channel::oneshot};
  4use parking_lot::{Mutex, RwLock};
  5use std::{
  6    marker::PhantomData,
  7    ops::Deref,
  8    sync::{Arc, LazyLock},
  9    thread,
 10};
 11use thread_local::ThreadLocal;
 12
 13use crate::{connection::Connection, domain::Migrator, util::UnboundedSyncSender};
 14
 15const MIGRATION_RETRIES: usize = 10;
 16
 17type QueuedWrite = Box<dyn 'static + Send + FnOnce()>;
 18type WriteQueue = Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>;
 19type WriteQueueConstructor = Box<dyn 'static + Send + FnMut() -> WriteQueue>;
 20
 21/// List of queues of tasks by database uri. This lets us serialize writes to the database
 22/// and have a single worker thread per db file. This means many thread safe connections
 23/// (possibly with different migrations) could all be communicating with the same background
 24/// thread.
 25static QUEUES: LazyLock<RwLock<HashMap<Arc<str>, WriteQueue>>> = LazyLock::new(Default::default);
 26
/// Thread safe connection to a given database file or in-memory db. This can be cloned, shared,
/// stored in a static, etc. It derefs to a synchronous, read-only connection belonging to the
/// calling thread. A write-capable connection may be accessed by passing a callback to the
/// `write` function, which queues the callback on the write queue registered for this database.
pub struct ThreadSafeConnection<M: Migrator + 'static = ()> {
    /// Database file path, or the shared-memory identifier when not `persistent`.
    uri: Arc<str>,
    /// `true` opens `uri` as a file; `false` opens a shared in-memory database named by `uri`.
    persistent: bool,
    /// Query run on every newly opened per-thread connection.
    connection_initialize_query: Option<&'static str>,
    /// One connection per thread, opened lazily on first deref and shared between clones.
    connections: Arc<ThreadLocal<Connection>>,
    /// Ties the migrator type to the connection without storing a value of it.
    _migrator: PhantomData<*mut M>,
}

// SAFETY: NOTE(review): these impls override the auto traits, which `PhantomData<*mut M>`
// suppresses (and which `Connection` presumably lacks as well). The soundness argument appears
// to be that connections are only reached through `ThreadLocal`, so each thread only uses the
// `Connection` it opened itself. However, the connections are dropped wherever the last clone
// of the `Arc<ThreadLocal<_>>` is dropped, which may be a different thread than the one that
// opened them — confirm that closing a connection from another thread is allowed.
unsafe impl<M: Migrator> Send for ThreadSafeConnection<M> {}
unsafe impl<M: Migrator> Sync for ThreadSafeConnection<M> {}
 38unsafe impl<M: Migrator> Send for ThreadSafeConnection<M> {}
 39unsafe impl<M: Migrator> Sync for ThreadSafeConnection<M> {}
 40
/// Configures and creates a [`ThreadSafeConnection`]. Obtained from
/// `ThreadSafeConnection::builder` and finished with `build`.
pub struct ThreadSafeConnectionBuilder<M: Migrator + 'static = ()> {
    /// Query run once on the write queue by `build`, before the migrations.
    db_initialize_query: Option<&'static str>,
    /// Used to create the write queue for this uri if none is registered yet; when `None`,
    /// a background thread queue is used.
    write_queue_constructor: Option<WriteQueueConstructor>,
    /// The connection being configured; returned by `build` on success.
    connection: ThreadSafeConnection<M>,
}
 46
 47impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
 48    /// Sets the query to run every time a connection is opened. This must
 49    /// be infallible (EG only use pragma statements) and not cause writes.
 50    /// to the db or it will panic.
 51    pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
 52        self.connection.connection_initialize_query = Some(initialize_query);
 53        self
 54    }
 55
 56    /// Queues an initialization query for the database file. This must be infallible
 57    /// but may cause changes to the database file such as with `PRAGMA journal_mode`
 58    pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
 59        self.db_initialize_query = Some(initialize_query);
 60        self
 61    }
 62
 63    /// Specifies how the thread safe connection should serialize writes. If provided
 64    /// the connection will call the write_queue_constructor for each database file in
 65    /// this process. The constructor is responsible for setting up a background thread or
 66    /// async task which handles queued writes with the provided connection.
 67    pub fn with_write_queue_constructor(
 68        mut self,
 69        write_queue_constructor: WriteQueueConstructor,
 70    ) -> Self {
 71        self.write_queue_constructor = Some(write_queue_constructor);
 72        self
 73    }
 74
 75    pub async fn build(self) -> anyhow::Result<ThreadSafeConnection<M>> {
 76        self.connection
 77            .initialize_queues(self.write_queue_constructor);
 78
 79        let db_initialize_query = self.db_initialize_query;
 80
 81        self.connection
 82            .write(move |connection| {
 83                if let Some(db_initialize_query) = db_initialize_query {
 84                    connection.exec(db_initialize_query).with_context(|| {
 85                        format!(
 86                            "Db initialize query failed to execute: {}",
 87                            db_initialize_query
 88                        )
 89                    })?()?;
 90                }
 91
 92                // Retry failed migrations in case they were run in parallel from different
 93                // processes. This gives a best attempt at migrating before bailing
 94                let mut migration_result =
 95                    anyhow::Result::<()>::Err(anyhow::anyhow!("Migration never run"));
 96
 97                for _ in 0..MIGRATION_RETRIES {
 98                    migration_result = connection
 99                        .with_savepoint("thread_safe_multi_migration", || M::migrate(connection));
100
101                    if migration_result.is_ok() {
102                        break;
103                    }
104                }
105
106                migration_result
107            })
108            .await?;
109
110        Ok(self.connection)
111    }
112}
113
114impl<M: Migrator> ThreadSafeConnection<M> {
115    fn initialize_queues(&self, write_queue_constructor: Option<WriteQueueConstructor>) -> bool {
116        if !QUEUES.read().contains_key(&self.uri) {
117            let mut queues = QUEUES.write();
118            if !queues.contains_key(&self.uri) {
119                let mut write_queue_constructor =
120                    write_queue_constructor.unwrap_or_else(background_thread_queue);
121                queues.insert(self.uri.clone(), write_queue_constructor());
122                return true;
123            }
124        }
125        false
126    }
127
128    pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
129        ThreadSafeConnectionBuilder::<M> {
130            db_initialize_query: None,
131            write_queue_constructor: None,
132            connection: Self {
133                uri: Arc::from(uri),
134                persistent,
135                connection_initialize_query: None,
136                connections: Default::default(),
137                _migrator: PhantomData,
138            },
139        }
140    }
141
142    /// Opens a new db connection with the initialized file path. This is internal and only
143    /// called from the deref function.
144    fn open_file(uri: &str) -> Connection {
145        Connection::open_file(uri)
146    }
147
148    /// Opens a shared memory connection using the file path as the identifier. This is internal
149    /// and only called from the deref function.
150    fn open_shared_memory(uri: &str) -> Connection {
151        Connection::open_memory(Some(uri))
152    }
153
154    pub fn write<T: 'static + Send + Sync>(
155        &self,
156        callback: impl 'static + Send + FnOnce(&Connection) -> T,
157    ) -> impl Future<Output = T> {
158        // Check and invalidate queue and maybe recreate queue
159        let queues = QUEUES.read();
160        let write_channel = queues
161            .get(&self.uri)
162            .expect("Queues are inserted when build is called. This should always succeed");
163
164        // Create a one shot channel for the result of the queued write
165        // so we can await on the result
166        let (sender, receiver) = oneshot::channel();
167
168        let thread_safe_connection = (*self).clone();
169        write_channel(Box::new(move || {
170            let connection = thread_safe_connection.deref();
171            let result = connection.with_write(|connection| callback(connection));
172            sender.send(result).ok();
173        }));
174        receiver.map(|response| response.expect("Write queue unexpectedly closed"))
175    }
176
177    pub(crate) fn create_connection(
178        persistent: bool,
179        uri: &str,
180        connection_initialize_query: Option<&'static str>,
181    ) -> Connection {
182        let mut connection = if persistent {
183            Self::open_file(uri)
184        } else {
185            Self::open_shared_memory(uri)
186        };
187
188        // Disallow writes on the connection. The only writes allowed for thread safe connections
189        // are from the background thread that can serialize them.
190        *connection.write.get_mut() = false;
191
192        if let Some(initialize_query) = connection_initialize_query {
193            connection.exec(initialize_query).unwrap_or_else(|_| {
194                panic!("Initialize query failed to execute: {}", initialize_query)
195            })()
196            .unwrap()
197        }
198
199        connection
200    }
201}
202
203impl ThreadSafeConnection<()> {
204    /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
205    /// This allows construction to be infallible and not write to the db.
206    pub fn new(
207        uri: &str,
208        persistent: bool,
209        connection_initialize_query: Option<&'static str>,
210        write_queue_constructor: Option<WriteQueueConstructor>,
211    ) -> Self {
212        let connection = Self {
213            uri: Arc::from(uri),
214            persistent,
215            connection_initialize_query,
216            connections: Default::default(),
217            _migrator: PhantomData,
218        };
219
220        connection.initialize_queues(write_queue_constructor);
221        connection
222    }
223}
224
225impl<M: Migrator> Clone for ThreadSafeConnection<M> {
226    fn clone(&self) -> Self {
227        Self {
228            uri: self.uri.clone(),
229            persistent: self.persistent,
230            connection_initialize_query: self.connection_initialize_query,
231            connections: self.connections.clone(),
232            _migrator: PhantomData,
233        }
234    }
235}
236
237impl<M: Migrator> Deref for ThreadSafeConnection<M> {
238    type Target = Connection;
239
240    fn deref(&self) -> &Self::Target {
241        self.connections.get_or(|| {
242            Self::create_connection(self.persistent, &self.uri, self.connection_initialize_query)
243        })
244    }
245}
246
247pub fn background_thread_queue() -> WriteQueueConstructor {
248    use std::sync::mpsc::channel;
249
250    Box::new(|| {
251        let (sender, receiver) = channel::<QueuedWrite>();
252
253        thread::spawn(move || {
254            while let Ok(write) = receiver.recv() {
255                write()
256            }
257        });
258
259        let sender = UnboundedSyncSender::new(sender);
260        Box::new(move |queued_write| {
261            sender
262                .send(queued_write)
263                .expect("Could not send write action to background thread");
264        })
265    })
266}
267
268pub fn locking_queue() -> WriteQueueConstructor {
269    Box::new(|| {
270        let write_mutex = Mutex::new(());
271        Box::new(move |queued_write| {
272            let _lock = write_mutex.lock();
273            queued_write();
274        })
275    })
276}
277
#[cfg(test)]
mod test {
    use indoc::indoc;
    use std::ops::Deref;

    use std::thread;

    use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};

    /// Builds 100 connections to the same shared in-memory database concurrently. Each build
    /// runs the db initialization query and the migrations through the shared write queue.
    #[test]
    fn many_initialize_and_migrate_queries_at_once() {
        enum TestDomain {}
        impl Domain for TestDomain {
            fn name() -> &'static str {
                "test"
            }
            fn migrations() -> &'static [&'static str] {
                &["CREATE TABLE test(col1 TEXT, col2 TEXT) STRICT;"]
            }
        }

        let mut workers = Vec::with_capacity(100);
        for _ in 0..100 {
            let worker = thread::spawn(|| {
                let builder =
                    ThreadSafeConnection::<TestDomain>::builder("annoying-test.db", false)
                        .with_db_initialization_query("PRAGMA journal_mode=WAL")
                        .with_connection_initialize_query(indoc! {"
                                PRAGMA synchronous=NORMAL;
                                PRAGMA busy_timeout=1;
                                PRAGMA foreign_keys=TRUE;
                                PRAGMA case_sensitive_like=TRUE;
                            "});

                let _ = smol::block_on(builder.build()).unwrap().deref();
            });
            workers.push(worker);
        }

        // NOTE(review): join results are discarded, so a panic inside a worker thread does
        // not fail this test. Consider unwrapping them if the test should catch such failures.
        workers.into_iter().for_each(|worker| {
            let _ = worker.join();
        });
    }

    /// The `workspaces` migration declares a foreign key on `active_pane`, a column the table
    /// does not define. The migration therefore fails, and unwrapping the build result panics.
    #[test]
    #[should_panic]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}
        impl Domain for TestWorkspace {
            fn name() -> &'static str {
                "workspace"
            }

            fn migrations() -> &'static [&'static str] {
                &["
                    CREATE TABLE workspaces(
                        workspace_id INTEGER PRIMARY KEY,
                        dock_visible INTEGER, -- Boolean
                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                    ) STRICT;

                    CREATE TABLE panes(
                        pane_id INTEGER PRIMARY KEY,
                        workspace_id INTEGER NOT NULL,
                        active INTEGER NOT NULL, -- Boolean
                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
                            ON DELETE CASCADE
                            ON UPDATE CASCADE
                    ) STRICT;
                "]
            }
        }

        let builder =
            ThreadSafeConnection::<TestWorkspace>::builder("wild_zed_lost_failure", false)
                .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true");

        let build_result = smol::block_on(builder.build());
        build_result.unwrap();
    }
}