thread_safe_connection.rs

 1use std::{ops::Deref, sync::Arc};
 2
 3use connection::Connection;
 4use thread_local::ThreadLocal;
 5
 6use crate::{connection, migrations::Migration};
 7
 8pub struct ThreadSafeConnection {
 9    uri: Arc<str>,
10    persistent: bool,
11    initialize_query: Option<&'static str>,
12    migrations: Option<&'static [Migration]>,
13    connection: Arc<ThreadLocal<Connection>>,
14}
15
16impl ThreadSafeConnection {
17    pub fn new(uri: &str, persistent: bool) -> Self {
18        Self {
19            uri: Arc::from(uri),
20            persistent,
21            initialize_query: None,
22            migrations: None,
23            connection: Default::default(),
24        }
25    }
26
27    /// Sets the query to run every time a connection is opened. This must
28    /// be infallible (EG only use pragma statements)
29    pub fn with_initialize_query(mut self, initialize_query: &'static str) -> Self {
30        self.initialize_query = Some(initialize_query);
31        self
32    }
33
34    /// Migrations have to be run per connection because we fallback to memory
35    /// so this needs
36    pub fn with_migrations(mut self, migrations: &'static [Migration]) -> Self {
37        self.migrations = Some(migrations);
38        self
39    }
40
41    /// Opens a new db connection with the initialized file path. This is internal and only
42    /// called from the deref function.
43    /// If opening fails, the connection falls back to a shared memory connection
44    fn open_file(&self) -> Connection {
45        Connection::open_file(self.uri.as_ref())
46    }
47
48    /// Opens a shared memory connection using the file path as the identifier. This unwraps
49    /// as we expect it always to succeed
50    fn open_shared_memory(&self) -> Connection {
51        Connection::open_memory(self.uri.as_ref())
52    }
53}
54
55impl Clone for ThreadSafeConnection {
56    fn clone(&self) -> Self {
57        Self {
58            uri: self.uri.clone(),
59            persistent: self.persistent,
60            initialize_query: self.initialize_query.clone(),
61            migrations: self.migrations.clone(),
62            connection: self.connection.clone(),
63        }
64    }
65}
66
67impl Deref for ThreadSafeConnection {
68    type Target = Connection;
69
70    fn deref(&self) -> &Self::Target {
71        self.connection.get_or(|| {
72            let connection = if self.persistent {
73                self.open_file()
74            } else {
75                self.open_shared_memory()
76            };
77
78            if let Some(initialize_query) = self.initialize_query {
79                connection.exec(initialize_query).expect(&format!(
80                    "Initialize query failed to execute: {}",
81                    initialize_query
82                ));
83            }
84
85            if let Some(migrations) = self.migrations {
86                for migration in migrations {
87                    migration
88                        .run(&connection)
89                        .expect(&format!("Migrations failed to execute: {:?}", migration));
90                }
91            }
92
93            connection
94        })
95    }
96}