thread_safe_connection.rs

  1use futures::{Future, FutureExt};
  2use lazy_static::lazy_static;
  3use parking_lot::RwLock;
  4use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
  5use thread_local::ThreadLocal;
  6
  7use crate::{
  8    connection::Connection,
  9    domain::{Domain, Migrator},
 10    util::UnboundedSyncSender,
 11};
 12
 13type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
 14
 15lazy_static! {
 16    static ref QUEUES: RwLock<HashMap<Arc<str>, UnboundedSyncSender<QueuedWrite>>> =
 17        Default::default();
 18}
 19
/// A handle to a SQLite database that can be shared across threads.
///
/// Dereferencing gives the calling thread its own lazily opened connection (cached in
/// `connections`, with writes disabled). Writes go through `write`, which queues them to a
/// background thread shared by every handle with the same `uri`.
pub struct ThreadSafeConnection<M: Migrator> {
    /// Path of the database file when `persistent`, otherwise the name of the shared
    /// in-memory database.
    uri: Arc<str>,
    /// Whether `uri` is opened as a file (`true`) or as a shared in-memory database (`false`).
    persistent: bool,
    /// Query run on every newly opened connection, before the migrations.
    initialize_query: Option<&'static str>,
    /// One connection per thread; shared between clones of this handle.
    connections: Arc<ThreadLocal<Connection>>,
    /// Selects the migrator whose migrations run on each newly opened connection.
    _migrator: PhantomData<M>,
}

// SAFETY: NOTE(review) — the intended invariant appears to be that each `Connection` in
// `connections` is only handed out (via `ThreadLocal::get_or` in `deref`) to the thread that
// opened it. When the last `Arc` to `connections` is dropped, however, every thread's connection
// is closed on the dropping thread, which relies on SQLite / `Connection`'s drop tolerating that —
// confirm the SQLite threading mode. Also note `thread::spawn` in `write` moves a `Connection`,
// so `Connection: Send` already holds; if `ThreadLocal<T>` is `Send + Sync` for `T: Send`
// (as in the `thread_local` crate — confirm), these impls may only be needed for `M`.
unsafe impl<T: Migrator> Send for ThreadSafeConnection<T> {}
unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
 30
 31impl<M: Migrator> ThreadSafeConnection<M> {
 32    pub fn new(uri: &str, persistent: bool) -> Self {
 33        Self {
 34            uri: Arc::from(uri),
 35            persistent,
 36            initialize_query: None,
 37            connections: Default::default(),
 38            _migrator: PhantomData,
 39        }
 40    }
 41
 42    /// Sets the query to run every time a connection is opened. This must
 43    /// be infallible (EG only use pragma statements)
 44    pub fn with_initialize_query(mut self, initialize_query: &'static str) -> Self {
 45        self.initialize_query = Some(initialize_query);
 46        self
 47    }
 48
 49    /// Opens a new db connection with the initialized file path. This is internal and only
 50    /// called from the deref function.
 51    /// If opening fails, the connection falls back to a shared memory connection
 52    fn open_file(&self) -> Connection {
 53        // This unwrap is secured by a panic in the constructor. Be careful if you remove it!
 54        Connection::open_file(self.uri.as_ref())
 55    }
 56
 57    /// Opens a shared memory connection using the file path as the identifier. This unwraps
 58    /// as we expect it always to succeed
 59    fn open_shared_memory(&self) -> Connection {
 60        Connection::open_memory(Some(self.uri.as_ref()))
 61    }
 62
 63    // Open a new connection for the given domain, leaving this
 64    // connection intact.
 65    pub fn for_domain<D2: Domain>(&self) -> ThreadSafeConnection<D2> {
 66        ThreadSafeConnection {
 67            uri: self.uri.clone(),
 68            persistent: self.persistent,
 69            initialize_query: self.initialize_query,
 70            connections: Default::default(),
 71            _migrator: PhantomData,
 72        }
 73    }
 74
 75    pub fn write<T: 'static + Send + Sync>(
 76        &self,
 77        callback: impl 'static + Send + FnOnce(&Connection) -> T,
 78    ) -> impl Future<Output = T> {
 79        // Startup write thread for this database if one hasn't already
 80        // been started and insert a channel to queue work for it
 81        if !QUEUES.read().contains_key(&self.uri) {
 82            use std::sync::mpsc::channel;
 83
 84            let (sender, reciever) = channel::<QueuedWrite>();
 85            let mut write_connection = self.create_connection();
 86            // Enable writes for this connection
 87            write_connection.write = true;
 88            thread::spawn(move || {
 89                while let Ok(write) = reciever.recv() {
 90                    write(&write_connection)
 91                }
 92            });
 93
 94            let mut queues = QUEUES.write();
 95            queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender));
 96        }
 97
 98        // Grab the queue for this database
 99        let queues = QUEUES.read();
100        let write_channel = queues.get(&self.uri).unwrap();
101
102        // Create a one shot channel for the result of the queued write
103        // so we can await on the result
104        let (sender, reciever) = futures::channel::oneshot::channel();
105        write_channel
106            .send(Box::new(move |connection| {
107                sender.send(callback(connection)).ok();
108            }))
109            .expect("Could not send write action to background thread");
110
111        reciever.map(|response| response.expect("Background thread unexpectedly closed"))
112    }
113
114    pub(crate) fn create_connection(&self) -> Connection {
115        let mut connection = if self.persistent {
116            self.open_file()
117        } else {
118            self.open_shared_memory()
119        };
120
121        // Enable writes for the migrations and initialization queries
122        connection.write = true;
123
124        if let Some(initialize_query) = self.initialize_query {
125            connection.exec(initialize_query).expect(&format!(
126                "Initialize query failed to execute: {}",
127                initialize_query
128            ))()
129            .unwrap();
130        }
131
132        M::migrate(&connection).expect("Migrations failed");
133
134        // Disable db writes for normal thread local connection
135        connection.write = false;
136        connection
137    }
138}
139
140impl<D: Domain> Clone for ThreadSafeConnection<D> {
141    fn clone(&self) -> Self {
142        Self {
143            uri: self.uri.clone(),
144            persistent: self.persistent,
145            initialize_query: self.initialize_query.clone(),
146            connections: self.connections.clone(),
147            _migrator: PhantomData,
148        }
149    }
150}
151
152// TODO:
153//  1. When migration or initialization fails, move the corrupted db to a holding place and create a new one
154//  2. If the new db also fails, downgrade to a shared in memory db
155//  3. In either case notify the user about what went wrong
156impl<M: Migrator> Deref for ThreadSafeConnection<M> {
157    type Target = Connection;
158
159    fn deref(&self) -> &Self::Target {
160        self.connections.get_or(|| self.create_connection())
161    }
162}
163
#[cfg(test)]
mod test {
    use super::ThreadSafeConnection;
    use crate::domain::Domain;

    /// Expects opening a connection to panic. The migration below is invalid: `workspaces`
    /// declares `FOREIGN KEY(active_pane)` but has no `active_pane` column, which SQLite is
    /// expected to reject when the migration runs.
    #[test]
    #[should_panic]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}
        impl Domain for TestWorkspace {
            fn name() -> &'static str {
                "workspace"
            }

            fn migrations() -> &'static [&'static str] {
                &["
                    CREATE TABLE workspaces(
                        workspace_id INTEGER PRIMARY KEY,
                        dock_visible INTEGER, -- Boolean
                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                    ) STRICT;
                    
                    CREATE TABLE panes(
                        pane_id INTEGER PRIMARY KEY,
                        workspace_id INTEGER NOT NULL,
                        active INTEGER NOT NULL, -- Boolean
                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) 
                            ON DELETE CASCADE 
                            ON UPDATE CASCADE
                    ) STRICT;
                "]
            }
        }

        let connection =
            ThreadSafeConnection::<TestWorkspace>::new("wild_zed_lost_failure", false)
                .with_initialize_query("PRAGMA FOREIGN_KEYS=true");
        // Dereferencing opens this thread's connection, which runs the migrations.
        let _ = &*connection;
    }
}