1use futures::{channel::oneshot, Future, FutureExt};
2use lazy_static::lazy_static;
3use parking_lot::RwLock;
4use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
5use thread_local::ThreadLocal;
6
7use crate::{
8 connection::Connection,
9 domain::{Domain, Migrator},
10 util::UnboundedSyncSender,
11};
12
13const MIGRATION_RETRIES: usize = 10;
14
15type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
16lazy_static! {
17 /// List of queues of tasks by database uri. This lets us serialize writes to the database
18 /// and have a single worker thread per db file. This means many thread safe connections
19 /// (possibly with different migrations) could all be communicating with the same background
20 /// thread.
21 static ref QUEUES: RwLock<HashMap<Arc<str>, UnboundedSyncSender<QueuedWrite>>> =
22 Default::default();
23}
24
/// Thread safe connection to a given database file or in-memory db. This can be cloned,
/// shared, stored in a static, etc. It derefs to a synchronous, read-only connection that
/// is created separately for each thread. A write-capable connection may be accessed by
/// passing a callback to the `write` function, which queues the callback on the single
/// background writer thread for this database.
pub struct ThreadSafeConnection<M: Migrator = ()> {
    /// Database file path, or the identifier of a shared in-memory db. Also the key into
    /// `QUEUES`, so every connection with the same uri shares one writer thread.
    uri: Arc<str>,
    /// `true` opens `uri` as a file; `false` opens a shared in-memory db named by `uri`.
    persistent: bool,
    /// Query run on every newly opened connection (see `create_connection`).
    connection_initialize_query: Option<&'static str>,
    /// Per-thread read-only connections, created lazily by `deref`. Shared between clones.
    connections: Arc<ThreadLocal<Connection>>,
    /// Ties the migration set `M` to this connection type without storing a value.
    _migrator: PhantomData<M>,
}
35
// SAFETY: NOTE(review): these impls assert `Send`/`Sync` unconditionally instead of
// letting the compiler derive them from the fields. This is presumably because
// `Connection` (defined elsewhere) is not `Send`, which would make
// `ThreadLocal<Connection>` neither `Send` nor `Sync` — confirm.
// The intended invariant appears to be:
// - each thread only uses the `Connection` that `ThreadLocal` created for it in `deref`;
// - writes go through the single writer thread started in `queue_write_task`.
// Note that when the last `Arc<ThreadLocal<Connection>>` is dropped, all per-thread
// connections are dropped on whichever thread performs that drop.
unsafe impl<T: Migrator> Send for ThreadSafeConnection<T> {}
unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
38
/// Builder returned by `ThreadSafeConnection::builder`. Configure the per-connection and
/// per-database initialization queries, then call `build` to run the db initialization
/// query and the migrations.
pub struct ThreadSafeConnectionBuilder<M: Migrator = ()> {
    /// Query run once, on the background writer thread, when `build` is called.
    db_initialize_query: Option<&'static str>,
    /// The connection being configured; returned by `build`.
    connection: ThreadSafeConnection<M>,
}
43
44impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
45 /// Sets the query to run every time a connection is opened. This must
46 /// be infallible (EG only use pragma statements) and not cause writes.
47 /// to the db or it will panic.
48 pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
49 self.connection.connection_initialize_query = Some(initialize_query);
50 self
51 }
52
53 /// Queues an initialization query for the database file. This must be infallible
54 /// but may cause changes to the database file such as with `PRAGMA journal_mode`
55 pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
56 self.db_initialize_query = Some(initialize_query);
57 self
58 }
59
60 pub async fn build(self) -> ThreadSafeConnection<M> {
61 let db_initialize_query = self.db_initialize_query;
62
63 self.connection
64 .write(move |connection| {
65 if let Some(db_initialize_query) = db_initialize_query {
66 connection.exec(db_initialize_query).expect(&format!(
67 "Db initialize query failed to execute: {}",
68 db_initialize_query
69 ))()
70 .unwrap();
71 }
72
73 let mut failure_result = None;
74 for _ in 0..MIGRATION_RETRIES {
75 failure_result = Some(M::migrate(connection));
76 if failure_result.as_ref().unwrap().is_ok() {
77 break;
78 }
79 }
80
81 failure_result.unwrap().expect("Migration failed");
82 })
83 .await;
84
85 self.connection
86 }
87}
88
89impl<M: Migrator> ThreadSafeConnection<M> {
90 pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
91 ThreadSafeConnectionBuilder::<M> {
92 db_initialize_query: None,
93 connection: Self {
94 uri: Arc::from(uri),
95 persistent,
96 connection_initialize_query: None,
97 connections: Default::default(),
98 _migrator: PhantomData,
99 },
100 }
101 }
102
103 /// Opens a new db connection with the initialized file path. This is internal and only
104 /// called from the deref function.
105 fn open_file(&self) -> Connection {
106 Connection::open_file(self.uri.as_ref())
107 }
108
109 /// Opens a shared memory connection using the file path as the identifier. This is internal
110 /// and only called from the deref function.
111 fn open_shared_memory(&self) -> Connection {
112 Connection::open_memory(Some(self.uri.as_ref()))
113 }
114
115 fn queue_write_task(&self, callback: QueuedWrite) {
116 // Startup write thread for this database if one hasn't already
117 // been started and insert a channel to queue work for it
118 if !QUEUES.read().contains_key(&self.uri) {
119 let mut queues = QUEUES.write();
120 if !queues.contains_key(&self.uri) {
121 use std::sync::mpsc::channel;
122
123 let (sender, reciever) = channel::<QueuedWrite>();
124 let mut write_connection = self.create_connection();
125 // Enable writes for this connection
126 write_connection.write = true;
127 thread::spawn(move || {
128 while let Ok(write) = reciever.recv() {
129 write(&write_connection)
130 }
131 });
132
133 queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender));
134 }
135 }
136
137 // Grab the queue for this database
138 let queues = QUEUES.read();
139 let write_channel = queues.get(&self.uri).unwrap();
140
141 write_channel
142 .send(callback)
143 .expect("Could not send write action to backgorund thread");
144 }
145
146 pub fn write<T: 'static + Send + Sync>(
147 &self,
148 callback: impl 'static + Send + FnOnce(&Connection) -> T,
149 ) -> impl Future<Output = T> {
150 // Create a one shot channel for the result of the queued write
151 // so we can await on the result
152 let (sender, reciever) = oneshot::channel();
153 self.queue_write_task(Box::new(move |connection| {
154 sender.send(callback(connection)).ok();
155 }));
156
157 reciever.map(|response| response.expect("Background writer thread unexpectedly closed"))
158 }
159
160 pub(crate) fn create_connection(&self) -> Connection {
161 let mut connection = if self.persistent {
162 self.open_file()
163 } else {
164 self.open_shared_memory()
165 };
166
167 // Disallow writes on the connection. The only writes allowed for thread safe connections
168 // are from the background thread that can serialize them.
169 connection.write = false;
170
171 if let Some(initialize_query) = self.connection_initialize_query {
172 connection.exec(initialize_query).expect(&format!(
173 "Initialize query failed to execute: {}",
174 initialize_query
175 ))()
176 .unwrap()
177 }
178
179 connection
180 }
181}
182
183impl ThreadSafeConnection<()> {
184 /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
185 /// This allows construction to be infallible and not write to the db.
186 pub fn new(
187 uri: &str,
188 persistent: bool,
189 connection_initialize_query: Option<&'static str>,
190 ) -> Self {
191 Self {
192 uri: Arc::from(uri),
193 persistent,
194 connection_initialize_query,
195 connections: Default::default(),
196 _migrator: PhantomData,
197 }
198 }
199}
200
201impl<D: Domain> Clone for ThreadSafeConnection<D> {
202 fn clone(&self) -> Self {
203 Self {
204 uri: self.uri.clone(),
205 persistent: self.persistent,
206 connection_initialize_query: self.connection_initialize_query.clone(),
207 connections: self.connections.clone(),
208 _migrator: PhantomData,
209 }
210 }
211}
212
213// TODO:
214// 1. When migration or initialization fails, move the corrupted db to a holding place and create a new one
215// 2. If the new db also fails, downgrade to a shared in memory db
216// 3. In either case notify the user about what went wrong
217impl<M: Migrator> Deref for ThreadSafeConnection<M> {
218 type Target = Connection;
219
220 fn deref(&self) -> &Self::Target {
221 self.connections.get_or(|| self.create_connection())
222 }
223}
224
#[cfg(test)]
mod test {
    use indoc::indoc;
    use std::{ops::Deref, thread};

    use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};

    /// Many threads concurrently build (and therefore initialize and migrate) connections
    /// to the same database, then each opens its own read connection.
    #[test]
    fn many_initialize_and_migrate_queries_at_once() {
        enum TestDomain {}
        impl Domain for TestDomain {
            fn name() -> &'static str {
                "test"
            }
            fn migrations() -> &'static [&'static str] {
                &["CREATE TABLE test(col1 TEXT, col2 TEXT) STRICT;"]
            }
        }

        let handles: Vec<_> = (0..100)
            .map(|_| {
                thread::spawn(|| {
                    let builder =
                        ThreadSafeConnection::<TestDomain>::builder("annoying-test.db", false)
                            .with_db_initialization_query("PRAGMA journal_mode=WAL")
                            .with_connection_initialize_query(indoc! {"
                                PRAGMA synchronous=NORMAL;
                                PRAGMA busy_timeout=1;
                                PRAGMA foreign_keys=TRUE;
                                PRAGMA case_sensitive_like=TRUE;
                            "});
                    let connection = smol::block_on(builder.build());
                    // Force this thread's read connection to be opened.
                    let _ = connection.deref();
                })
            })
            .collect();

        for handle in handles {
            // Propagate worker panics; ignoring them meant this test could never fail.
            handle
                .join()
                .expect("a thread failed to initialize or migrate the database");
        }
    }

    /// The `workspaces` table declares `FOREIGN KEY(active_pane)` without an `active_pane`
    /// column, so the migration is expected to fail and `build` to panic.
    #[test]
    #[should_panic]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}
        impl Domain for TestWorkspace {
            fn name() -> &'static str {
                "workspace"
            }

            fn migrations() -> &'static [&'static str] {
                &["
                    CREATE TABLE workspaces(
                        workspace_id INTEGER PRIMARY KEY,
                        dock_visible INTEGER, -- Boolean
                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                    ) STRICT;

                    CREATE TABLE panes(
                        pane_id INTEGER PRIMARY KEY,
                        workspace_id INTEGER NOT NULL,
                        active INTEGER NOT NULL, -- Boolean
                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
                            ON DELETE CASCADE
                            ON UPDATE CASCADE
                    ) STRICT;
                "]
            }
        }

        let builder =
            ThreadSafeConnection::<TestWorkspace>::builder("wild_zed_lost_failure", false)
                .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true");

        smol::block_on(builder.build());
    }
}