1use std::{ops::Deref, sync::Arc};
2
3use connection::Connection;
4use thread_local::ThreadLocal;
5
6use crate::{connection, migrations::Migration};
7
8pub struct ThreadSafeConnection {
9 uri: Arc<str>,
10 persistent: bool,
11 initialize_query: Option<&'static str>,
12 migrations: Option<&'static [Migration]>,
13 connection: Arc<ThreadLocal<Connection>>,
14}
15
16impl ThreadSafeConnection {
17 pub fn new(uri: &str, persistent: bool) -> Self {
18 Self {
19 uri: Arc::from(uri),
20 persistent,
21 initialize_query: None,
22 migrations: None,
23 connection: Default::default(),
24 }
25 }
26
27 /// Sets the query to run every time a connection is opened. This must
28 /// be infallible (EG only use pragma statements)
29 pub fn with_initialize_query(mut self, initialize_query: &'static str) -> Self {
30 self.initialize_query = Some(initialize_query);
31 self
32 }
33
34 /// Migrations have to be run per connection because we fallback to memory
35 /// so this needs
36 pub fn with_migrations(mut self, migrations: &'static [Migration]) -> Self {
37 self.migrations = Some(migrations);
38 self
39 }
40
41 /// Opens a new db connection with the initialized file path. This is internal and only
42 /// called from the deref function.
43 /// If opening fails, the connection falls back to a shared memory connection
44 fn open_file(&self) -> Connection {
45 Connection::open_file(self.uri.as_ref())
46 }
47
48 /// Opens a shared memory connection using the file path as the identifier. This unwraps
49 /// as we expect it always to succeed
50 fn open_shared_memory(&self) -> Connection {
51 Connection::open_memory(self.uri.as_ref())
52 }
53}
54
55impl Clone for ThreadSafeConnection {
56 fn clone(&self) -> Self {
57 Self {
58 uri: self.uri.clone(),
59 persistent: self.persistent,
60 initialize_query: self.initialize_query.clone(),
61 migrations: self.migrations.clone(),
62 connection: self.connection.clone(),
63 }
64 }
65}
66
67impl Deref for ThreadSafeConnection {
68 type Target = Connection;
69
70 fn deref(&self) -> &Self::Target {
71 self.connection.get_or(|| {
72 let connection = if self.persistent {
73 self.open_file()
74 } else {
75 self.open_shared_memory()
76 };
77
78 if let Some(initialize_query) = self.initialize_query {
79 connection.exec(initialize_query).expect(&format!(
80 "Initialize query failed to execute: {}",
81 initialize_query
82 ));
83 }
84
85 if let Some(migrations) = self.migrations {
86 for migration in migrations {
87 migration
88 .run(&connection)
89 .expect(&format!("Migrations failed to execute: {:?}", migration));
90 }
91 }
92
93 connection
94 })
95 }
96}