thread_safe_connection.rs

  1use std::{marker::PhantomData, ops::Deref, sync::Arc};
  2
  3use connection::Connection;
  4use thread_local::ThreadLocal;
  5
  6use crate::{
  7    connection,
  8    domain::{Domain, Migrator},
  9};
 10
 11pub struct ThreadSafeConnection<M: Migrator> {
 12    uri: Option<Arc<str>>,
 13    persistent: bool,
 14    initialize_query: Option<&'static str>,
 15    connection: Arc<ThreadLocal<Connection>>,
 16    _pd: PhantomData<M>,
 17}
 18
 19unsafe impl<T: Migrator> Send for ThreadSafeConnection<T> {}
 20unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
 21
 22impl<M: Migrator> ThreadSafeConnection<M> {
 23    pub fn new(uri: Option<&str>, persistent: bool) -> Self {
 24        if persistent == true && uri == None {
 25            // This panic is securing the unwrap in open_file(), don't remove it!
 26            panic!("Cannot create a persistent connection without a URI")
 27        }
 28        Self {
 29            uri: uri.map(|str| Arc::from(str)),
 30            persistent,
 31            initialize_query: None,
 32            connection: Default::default(),
 33            _pd: PhantomData,
 34        }
 35    }
 36
 37    /// Sets the query to run every time a connection is opened. This must
 38    /// be infallible (EG only use pragma statements)
 39    pub fn with_initialize_query(mut self, initialize_query: &'static str) -> Self {
 40        self.initialize_query = Some(initialize_query);
 41        self
 42    }
 43
 44    /// Opens a new db connection with the initialized file path. This is internal and only
 45    /// called from the deref function.
 46    /// If opening fails, the connection falls back to a shared memory connection
 47    fn open_file(&self) -> Connection {
 48        // This unwrap is secured by a panic in the constructor. Be careful if you remove it!
 49        Connection::open_file(self.uri.as_ref().unwrap())
 50    }
 51
 52    /// Opens a shared memory connection using the file path as the identifier. This unwraps
 53    /// as we expect it always to succeed
 54    fn open_shared_memory(&self) -> Connection {
 55        Connection::open_memory(self.uri.as_ref().map(|str| str.deref()))
 56    }
 57
 58    // Open a new connection for the given domain, leaving this
 59    // connection intact.
 60    pub fn for_domain<D2: Domain>(&self) -> ThreadSafeConnection<D2> {
 61        ThreadSafeConnection {
 62            uri: self.uri.clone(),
 63            persistent: self.persistent,
 64            initialize_query: self.initialize_query,
 65            connection: Default::default(),
 66            _pd: PhantomData,
 67        }
 68    }
 69}
 70
 71impl<D: Domain> Clone for ThreadSafeConnection<D> {
 72    fn clone(&self) -> Self {
 73        Self {
 74            uri: self.uri.clone(),
 75            persistent: self.persistent,
 76            initialize_query: self.initialize_query.clone(),
 77            connection: self.connection.clone(),
 78            _pd: PhantomData,
 79        }
 80    }
 81}
 82
 83// TODO:
 84//  1. When migration or initialization fails, move the corrupted db to a holding place and create a new one
 85//  2. If the new db also fails, downgrade to a shared in memory db
 86//  3. In either case notify the user about what went wrong
 87impl<M: Migrator> Deref for ThreadSafeConnection<M> {
 88    type Target = Connection;
 89
 90    fn deref(&self) -> &Self::Target {
 91        self.connection.get_or(|| {
 92            let connection = if self.persistent {
 93                self.open_file()
 94            } else {
 95                self.open_shared_memory()
 96            };
 97
 98            if let Some(initialize_query) = self.initialize_query {
 99                connection.exec(initialize_query).expect(&format!(
100                    "Initialize query failed to execute: {}",
101                    initialize_query
102                ))()
103                .unwrap();
104            }
105
106            M::migrate(&connection).expect("Migrations failed");
107
108            connection
109        })
110    }
111}
112
#[cfg(test)]
mod test {
    use std::ops::Deref;

    use super::ThreadSafeConnection;
    use crate::domain::Domain;

    /// Dereferencing a connection whose domain migrations cannot be applied must panic.
    ///
    /// NOTE(review): the `workspaces` table declares a foreign key on `active_pane`,
    /// a column it never defines. Presumably that is what makes the migration fail
    /// and triggers the expected panic - confirm before changing the schema.
    #[test]
    #[should_panic]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}
        impl Domain for TestWorkspace {
            fn name() -> &'static str {
                "workspace"
            }

            fn migrations() -> &'static [&'static str] {
                &["
                    CREATE TABLE workspaces(
                        workspace_id BLOB PRIMARY KEY,
                        dock_visible INTEGER, -- Boolean
                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                    ) STRICT;
                    
                    CREATE TABLE panes(
                        pane_id INTEGER PRIMARY KEY,
                        workspace_id BLOB NOT NULL,
                        active INTEGER NOT NULL, -- Boolean
                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id) 
                            ON DELETE CASCADE 
                            ON UPDATE CASCADE
                    ) STRICT;
                "]
            }
        }

        // Non-persistent handle with no URI; the first deref opens the connection,
        // runs the pragma and then the migrations.
        let connection = ThreadSafeConnection::<TestWorkspace>::new(None, false)
            .with_initialize_query("PRAGMA FOREIGN_KEYS=true");
        let _ = connection.deref();
    }
}