1use futures::{channel::oneshot, Future, FutureExt};
2use lazy_static::lazy_static;
3use parking_lot::{Mutex, RwLock};
4use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
5use thread_local::ThreadLocal;
6
7use crate::{
8 connection::Connection,
9 domain::{Domain, Migrator},
10 util::UnboundedSyncSender,
11};
12
13const MIGRATION_RETRIES: usize = 10;
14
15type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
16type WriteQueueConstructor =
17 Box<dyn 'static + Send + FnMut(Connection) -> Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>;
18lazy_static! {
19 /// List of queues of tasks by database uri. This lets us serialize writes to the database
20 /// and have a single worker thread per db file. This means many thread safe connections
21 /// (possibly with different migrations) could all be communicating with the same background
22 /// thread.
23 static ref QUEUES: RwLock<HashMap<Arc<str>, Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>> =
24 Default::default();
25}
26
/// Thread safe connection to a given database file or shared in-memory db. The value
/// itself can be shared between threads or stored in a static. Dereferencing it yields
/// a synchronous [`Connection`] owned by the current thread, on which writes are
/// disabled. Writes are made by passing a callback to the `write` function, which
/// queues the callback to run on the database's single write connection.
pub struct ThreadSafeConnection<M: Migrator = ()> {
    /// Database file path when `persistent`, otherwise the name of the shared
    /// in-memory database.
    uri: Arc<str>,
    /// `true` opens `uri` as a file; `false` opens a shared in-memory database.
    persistent: bool,
    /// Query executed on every connection as it is opened, if set.
    connection_initialize_query: Option<&'static str>,
    /// Lazily created per-thread read connections.
    connections: Arc<ThreadLocal<Connection>>,
    /// Marker for the migrations type; no `M` value is stored.
    _migrator: PhantomData<M>,
}
37
38unsafe impl<T: Migrator> Send for ThreadSafeConnection<T> {}
39unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
40
/// Configures and creates a [`ThreadSafeConnection`]. Obtain one from
/// `ThreadSafeConnection::builder` and finish it with `build`.
pub struct ThreadSafeConnectionBuilder<M: Migrator = ()> {
    /// Query run once, on the write connection, before migrations (see `build`).
    db_initialize_query: Option<&'static str>,
    /// Custom write queue; `None` falls back to a dedicated background writer thread.
    write_queue_constructor: Option<WriteQueueConstructor>,
    /// The connection being configured, returned by `build`.
    connection: ThreadSafeConnection<M>,
}
46
47impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
48 /// Sets the query to run every time a connection is opened. This must
49 /// be infallible (EG only use pragma statements) and not cause writes.
50 /// to the db or it will panic.
51 pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
52 self.connection.connection_initialize_query = Some(initialize_query);
53 self
54 }
55
56 /// Specifies how the thread safe connection should serialize writes. If provided
57 /// the connection will call the write_queue_constructor for each database file in
58 /// this process. The constructor is responsible for setting up a background thread or
59 /// async task which handles queued writes with the provided connection.
60 pub fn with_write_queue_constructor(
61 mut self,
62 write_queue_constructor: WriteQueueConstructor,
63 ) -> Self {
64 self.write_queue_constructor = Some(write_queue_constructor);
65 self
66 }
67
68 /// Queues an initialization query for the database file. This must be infallible
69 /// but may cause changes to the database file such as with `PRAGMA journal_mode`
70 pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
71 self.db_initialize_query = Some(initialize_query);
72 self
73 }
74
75 pub async fn build(self) -> ThreadSafeConnection<M> {
76 self.connection
77 .initialize_queues(self.write_queue_constructor);
78
79 let db_initialize_query = self.db_initialize_query;
80
81 self.connection
82 .write(move |connection| {
83 if let Some(db_initialize_query) = db_initialize_query {
84 connection.exec(db_initialize_query).expect(&format!(
85 "Db initialize query failed to execute: {}",
86 db_initialize_query
87 ))()
88 .unwrap();
89 }
90
91 let mut failure_result = None;
92 for _ in 0..MIGRATION_RETRIES {
93 failure_result = Some(M::migrate(connection));
94 if failure_result.as_ref().unwrap().is_ok() {
95 break;
96 }
97 }
98
99 failure_result.unwrap().expect("Migration failed");
100 })
101 .await;
102
103 self.connection
104 }
105}
106
107impl<M: Migrator> ThreadSafeConnection<M> {
108 fn initialize_queues(&self, write_queue_constructor: Option<WriteQueueConstructor>) {
109 if !QUEUES.read().contains_key(&self.uri) {
110 let mut queues = QUEUES.write();
111 if !queues.contains_key(&self.uri) {
112 let mut write_connection = self.create_connection();
113 // Enable writes for this connection
114 write_connection.write = true;
115 if let Some(mut write_queue_constructor) = write_queue_constructor {
116 let write_channel = write_queue_constructor(write_connection);
117 queues.insert(self.uri.clone(), write_channel);
118 } else {
119 use std::sync::mpsc::channel;
120
121 let (sender, reciever) = channel::<QueuedWrite>();
122 thread::spawn(move || {
123 while let Ok(write) = reciever.recv() {
124 write(&write_connection)
125 }
126 });
127
128 let sender = UnboundedSyncSender::new(sender);
129 queues.insert(
130 self.uri.clone(),
131 Box::new(move |queued_write| {
132 sender
133 .send(queued_write)
134 .expect("Could not send write action to backgorund thread");
135 }),
136 );
137 }
138 }
139 }
140 }
141
142 pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
143 ThreadSafeConnectionBuilder::<M> {
144 db_initialize_query: None,
145 write_queue_constructor: None,
146 connection: Self {
147 uri: Arc::from(uri),
148 persistent,
149 connection_initialize_query: None,
150 connections: Default::default(),
151 _migrator: PhantomData,
152 },
153 }
154 }
155
156 /// Opens a new db connection with the initialized file path. This is internal and only
157 /// called from the deref function.
158 fn open_file(&self) -> Connection {
159 Connection::open_file(self.uri.as_ref())
160 }
161
162 /// Opens a shared memory connection using the file path as the identifier. This is internal
163 /// and only called from the deref function.
164 fn open_shared_memory(&self) -> Connection {
165 Connection::open_memory(Some(self.uri.as_ref()))
166 }
167
168 pub fn write<T: 'static + Send + Sync>(
169 &self,
170 callback: impl 'static + Send + FnOnce(&Connection) -> T,
171 ) -> impl Future<Output = T> {
172 let queues = QUEUES.read();
173 let write_channel = queues
174 .get(&self.uri)
175 .expect("Queues are inserted when build is called. This should always succeed");
176
177 // Create a one shot channel for the result of the queued write
178 // so we can await on the result
179 let (sender, reciever) = oneshot::channel();
180 write_channel(Box::new(move |connection| {
181 sender.send(callback(connection)).ok();
182 }));
183 reciever.map(|response| response.expect("Background writer thread unexpectedly closed"))
184 }
185
186 pub(crate) fn create_connection(&self) -> Connection {
187 let mut connection = if self.persistent {
188 self.open_file()
189 } else {
190 self.open_shared_memory()
191 };
192
193 // Disallow writes on the connection. The only writes allowed for thread safe connections
194 // are from the background thread that can serialize them.
195 connection.write = false;
196
197 if let Some(initialize_query) = self.connection_initialize_query {
198 connection.exec(initialize_query).expect(&format!(
199 "Initialize query failed to execute: {}",
200 initialize_query
201 ))()
202 .unwrap()
203 }
204
205 connection
206 }
207}
208
209impl ThreadSafeConnection<()> {
210 /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
211 /// This allows construction to be infallible and not write to the db.
212 pub fn new(
213 uri: &str,
214 persistent: bool,
215 connection_initialize_query: Option<&'static str>,
216 write_queue_constructor: Option<WriteQueueConstructor>,
217 ) -> Self {
218 let connection = Self {
219 uri: Arc::from(uri),
220 persistent,
221 connection_initialize_query,
222 connections: Default::default(),
223 _migrator: PhantomData,
224 };
225
226 connection.initialize_queues(write_queue_constructor);
227 connection
228 }
229}
230
231impl<D: Domain> Clone for ThreadSafeConnection<D> {
232 fn clone(&self) -> Self {
233 Self {
234 uri: self.uri.clone(),
235 persistent: self.persistent,
236 connection_initialize_query: self.connection_initialize_query.clone(),
237 connections: self.connections.clone(),
238 _migrator: PhantomData,
239 }
240 }
241}
242
243// TODO:
244// 1. When migration or initialization fails, move the corrupted db to a holding place and create a new one
245// 2. If the new db also fails, downgrade to a shared in memory db
246// 3. In either case notify the user about what went wrong
247impl<M: Migrator> Deref for ThreadSafeConnection<M> {
248 type Target = Connection;
249
250 fn deref(&self) -> &Self::Target {
251 self.connections.get_or(|| self.create_connection())
252 }
253}
254
255pub fn locking_queue() -> WriteQueueConstructor {
256 Box::new(|connection| {
257 let connection = Mutex::new(connection);
258 Box::new(move |queued_write| {
259 let connection = connection.lock();
260 queued_write(&connection)
261 })
262 })
263}
264
#[cfg(test)]
mod test {
    use indoc::indoc;
    use std::ops::Deref;
    use std::thread;

    use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};

    /// Builds the same in-memory database from many threads at once. This exercises
    /// the race to create its write queue and the repeated, serialized migrations.
    #[test]
    fn many_initialize_and_migrate_queries_at_once() {
        let mut handles = vec![];

        enum TestDomain {}
        impl Domain for TestDomain {
            fn name() -> &'static str {
                "test"
            }
            fn migrations() -> &'static [&'static str] {
                &["CREATE TABLE test(col1 TEXT, col2 TEXT) STRICT;"]
            }
        }

        for _ in 0..100 {
            handles.push(thread::spawn(|| {
                let builder =
                    ThreadSafeConnection::<TestDomain>::builder("annoying-test.db", false)
                        .with_db_initialization_query("PRAGMA journal_mode=WAL")
                        .with_connection_initialize_query(indoc! {"
                            PRAGMA synchronous=NORMAL;
                            PRAGMA busy_timeout=1;
                            PRAGMA foreign_keys=TRUE;
                            PRAGMA case_sensitive_like=TRUE;
                        "});
                // Dereferencing also opens this thread's read connection.
                let _ = smol::block_on(builder.build()).deref();
            }));
        }

        // NOTE(review): join results are discarded, so a panic inside a worker thread
        // does not fail this test; in effect it only detects hangs or aborts.
        for handle in handles {
            let _ = handle.join();
        }
    }

    /// The `workspaces` table declares `FOREIGN KEY(active_pane)`, but `active_pane`
    /// is not one of its columns. The migration is therefore expected to fail on every
    /// attempt, making `build` panic.
    #[test]
    #[should_panic]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}
        impl Domain for TestWorkspace {
            fn name() -> &'static str {
                "workspace"
            }

            fn migrations() -> &'static [&'static str] {
                &["
                    CREATE TABLE workspaces(
                        workspace_id INTEGER PRIMARY KEY,
                        dock_visible INTEGER, -- Boolean
                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                    ) STRICT;

                    CREATE TABLE panes(
                        pane_id INTEGER PRIMARY KEY,
                        workspace_id INTEGER NOT NULL,
                        active INTEGER NOT NULL, -- Boolean
                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
                            ON DELETE CASCADE
                            ON UPDATE CASCADE
                    ) STRICT;
                "]
            }
        }

        let builder =
            ThreadSafeConnection::<TestWorkspace>::builder("wild_zed_lost_failure", false)
                .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true");

        smol::block_on(builder.build());
    }
}