1use std::{marker::PhantomData, ops::Deref, sync::Arc};
2
3use connection::Connection;
4use thread_local::ThreadLocal;
5
6use crate::{
7 connection,
8 domain::{Domain, Migrator},
9};
10
11pub struct ThreadSafeConnection<M: Migrator> {
12 uri: Option<Arc<str>>,
13 persistent: bool,
14 initialize_query: Option<&'static str>,
15 connection: Arc<ThreadLocal<Connection>>,
16 _pd: PhantomData<M>,
17}
18
// SAFETY: NOTE(review): these impls force `Send`/`Sync`, which the compiler would
// presumably withhold if `Connection` is not `Send` (`thread_local::ThreadLocal<T>` is only
// `Sync` for `T: Send`) — confirm. Within this file, a connection is only reached through
// `ThreadLocal::get_or` in `deref`, so each thread uses only the connection it opened.
// However, the `ThreadLocal` (and every per-thread `Connection` in it) is dropped on
// whichever thread releases the last `Arc` clone of it. Soundness therefore also relies on
// a `Connection` being safe to drop from a thread other than the one that opened it —
// verify against the SQLite threading mode in use before relying on this.
unsafe impl<T: Migrator> Send for ThreadSafeConnection<T> {}
unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
21
22impl<M: Migrator> ThreadSafeConnection<M> {
23 pub fn new(uri: Option<&str>, persistent: bool) -> Self {
24 if persistent == true && uri == None {
25 // This panic is securing the unwrap in open_file(), don't remove it!
26 panic!("Cannot create a persistent connection without a URI")
27 }
28 Self {
29 uri: uri.map(|str| Arc::from(str)),
30 persistent,
31 initialize_query: None,
32 connection: Default::default(),
33 _pd: PhantomData,
34 }
35 }
36
37 /// Sets the query to run every time a connection is opened. This must
38 /// be infallible (EG only use pragma statements)
39 pub fn with_initialize_query(mut self, initialize_query: &'static str) -> Self {
40 self.initialize_query = Some(initialize_query);
41 self
42 }
43
44 /// Opens a new db connection with the initialized file path. This is internal and only
45 /// called from the deref function.
46 /// If opening fails, the connection falls back to a shared memory connection
47 fn open_file(&self) -> Connection {
48 // This unwrap is secured by a panic in the constructor. Be careful if you remove it!
49 Connection::open_file(self.uri.as_ref().unwrap())
50 }
51
52 /// Opens a shared memory connection using the file path as the identifier. This unwraps
53 /// as we expect it always to succeed
54 fn open_shared_memory(&self) -> Connection {
55 Connection::open_memory(self.uri.as_ref().map(|str| str.deref()))
56 }
57
58 // Open a new connection for the given domain, leaving this
59 // connection intact.
60 pub fn for_domain<D2: Domain>(&self) -> ThreadSafeConnection<D2> {
61 ThreadSafeConnection {
62 uri: self.uri.clone(),
63 persistent: self.persistent,
64 initialize_query: self.initialize_query,
65 connection: Default::default(),
66 _pd: PhantomData,
67 }
68 }
69}
70
71impl<D: Domain> Clone for ThreadSafeConnection<D> {
72 fn clone(&self) -> Self {
73 Self {
74 uri: self.uri.clone(),
75 persistent: self.persistent,
76 initialize_query: self.initialize_query.clone(),
77 connection: self.connection.clone(),
78 _pd: PhantomData,
79 }
80 }
81}
82
83// TODO:
84// 1. When migration or initialization fails, move the corrupted db to a holding place and create a new one
85// 2. If the new db also fails, downgrade to a shared in memory db
86// 3. In either case notify the user about what went wrong
87impl<M: Migrator> Deref for ThreadSafeConnection<M> {
88 type Target = Connection;
89
90 fn deref(&self) -> &Self::Target {
91 self.connection.get_or(|| {
92 let connection = if self.persistent {
93 self.open_file()
94 } else {
95 self.open_shared_memory()
96 };
97
98 if let Some(initialize_query) = self.initialize_query {
99 connection.exec(initialize_query).expect(&format!(
100 "Initialize query failed to execute: {}",
101 initialize_query
102 ))()
103 .unwrap();
104 }
105
106 M::migrate(&connection).expect("Migrations failed");
107
108 connection
109 })
110 }
111}
112
#[cfg(test)]
mod test {
    use std::ops::Deref;

    use crate::domain::Domain;

    use super::ThreadSafeConnection;

    /// A migration that fails (the `workspaces` table declares a foreign key on
    /// `active_pane`, a column it does not have) must surface as a panic on the first
    /// dereference. The expected message pins the panic to the migration step, so the test
    /// cannot pass because of an unrelated failure (e.g. the initialize query).
    #[test]
    #[should_panic(expected = "Migrations failed")]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}
        impl Domain for TestWorkspace {
            fn name() -> &'static str {
                "workspace"
            }

            fn migrations() -> &'static [&'static str] {
                &["
                    CREATE TABLE workspaces(
                        workspace_id BLOB PRIMARY KEY,
                        dock_visible INTEGER, -- Boolean
                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                    ) STRICT;

                    CREATE TABLE panes(
                        pane_id INTEGER PRIMARY KEY,
                        workspace_id BLOB NOT NULL,
                        active INTEGER NOT NULL, -- Boolean
                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
                        ON DELETE CASCADE
                        ON UPDATE CASCADE
                    ) STRICT;
                "]
            }
        }

        let _ = ThreadSafeConnection::<TestWorkspace>::new(None, false)
            .with_initialize_query("PRAGMA FOREIGN_KEYS=true")
            .deref();
    }
}