1use anyhow::Context;
2use collections::HashMap;
3use futures::{channel::oneshot, Future, FutureExt};
4use lazy_static::lazy_static;
5use parking_lot::{Mutex, RwLock};
6use std::{marker::PhantomData, ops::Deref, sync::Arc, thread};
7use thread_local::ThreadLocal;
8
9use crate::{connection::Connection, domain::Migrator, util::UnboundedSyncSender};
10
11const MIGRATION_RETRIES: usize = 10;
12
13type QueuedWrite = Box<dyn 'static + Send + FnOnce()>;
14type WriteQueue = Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>;
15type WriteQueueConstructor = Box<dyn 'static + Send + FnMut() -> WriteQueue>;
16lazy_static! {
17 /// List of queues of tasks by database uri. This lets us serialize writes to the database
18 /// and have a single worker thread per db file. This means many thread safe connections
19 /// (possibly with different migrations) could all be communicating with the same background
20 /// thread.
21 static ref QUEUES: RwLock<HashMap<Arc<str>, WriteQueue>> =
22 Default::default();
23}
24
/// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static,
/// whatever. It derefs to a read-only connection owned by the calling thread (opened lazily, one per
/// thread). A write capable connection may be accessed by passing a callback to the `write` function,
/// which queues the callback on the database's shared write queue.
pub struct ThreadSafeConnection<M: Migrator + 'static = ()> {
    /// Database identifier: a file path when `persistent`, otherwise the name of a shared
    /// in-memory database. Also the key of this database's write queue in `QUEUES`.
    uri: Arc<str>,
    /// Selects `Connection::open_file` (true) or `Connection::open_memory` (false).
    persistent: bool,
    /// Query run on every newly opened per-thread connection.
    connection_initialize_query: Option<&'static str>,
    /// One read-only connection per thread, shared by all clones of this handle.
    connections: Arc<ThreadLocal<Connection>>,
    /// Ties the migration type `M` to the connection without storing a value. The raw pointer
    /// makes this field neither `Send` nor `Sync`, which is why the impls below are manual.
    _migrator: PhantomData<*mut M>,
}

// SAFETY(review): these impls assert thread safety manually because `PhantomData<*mut M>`
// opts out of the auto traits (and `Connection` is presumably not `Send`/`Sync` itself —
// confirm). Through `Deref`, each thread only reaches the connection stored for it in the
// `ThreadLocal`. However, dropping the last clone drops every per-thread connection on
// whichever thread performs that drop — TODO confirm the SQLite build is thread-safe so this
// is sound.
unsafe impl<M: Migrator> Send for ThreadSafeConnection<M> {}
unsafe impl<M: Migrator> Sync for ThreadSafeConnection<M> {}
38
/// Configures and constructs a [`ThreadSafeConnection`]; obtained from
/// `ThreadSafeConnection::builder` and finished with `build`.
pub struct ThreadSafeConnectionBuilder<M: Migrator + 'static = ()> {
    /// Query executed once through the write queue by `build`, before migrations run.
    db_initialize_query: Option<&'static str>,
    /// Custom write queue factory used by `build` if no queue exists yet for the uri.
    write_queue_constructor: Option<WriteQueueConstructor>,
    /// The connection being configured; returned by `build` on success.
    connection: ThreadSafeConnection<M>,
}
44
45impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
46 /// Sets the query to run every time a connection is opened. This must
47 /// be infallible (EG only use pragma statements) and not cause writes.
48 /// to the db or it will panic.
49 pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
50 self.connection.connection_initialize_query = Some(initialize_query);
51 self
52 }
53
54 /// Queues an initialization query for the database file. This must be infallible
55 /// but may cause changes to the database file such as with `PRAGMA journal_mode`
56 pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
57 self.db_initialize_query = Some(initialize_query);
58 self
59 }
60
61 /// Specifies how the thread safe connection should serialize writes. If provided
62 /// the connection will call the write_queue_constructor for each database file in
63 /// this process. The constructor is responsible for setting up a background thread or
64 /// async task which handles queued writes with the provided connection.
65 pub fn with_write_queue_constructor(
66 mut self,
67 write_queue_constructor: WriteQueueConstructor,
68 ) -> Self {
69 self.write_queue_constructor = Some(write_queue_constructor);
70 self
71 }
72
73 pub async fn build(self) -> anyhow::Result<ThreadSafeConnection<M>> {
74 self.connection
75 .initialize_queues(self.write_queue_constructor);
76
77 let db_initialize_query = self.db_initialize_query;
78
79 self.connection
80 .write(move |connection| {
81 if let Some(db_initialize_query) = db_initialize_query {
82 connection.exec(db_initialize_query).with_context(|| {
83 format!(
84 "Db initialize query failed to execute: {}",
85 db_initialize_query
86 )
87 })?()?;
88 }
89
90 // Retry failed migrations in case they were run in parallel from different
91 // processes. This gives a best attempt at migrating before bailing
92 let mut migration_result =
93 anyhow::Result::<()>::Err(anyhow::anyhow!("Migration never run"));
94
95 for _ in 0..MIGRATION_RETRIES {
96 migration_result = connection
97 .with_savepoint("thread_safe_multi_migration", || M::migrate(connection));
98
99 if migration_result.is_ok() {
100 break;
101 }
102 }
103
104 migration_result
105 })
106 .await?;
107
108 Ok(self.connection)
109 }
110}
111
112impl<M: Migrator> ThreadSafeConnection<M> {
113 fn initialize_queues(&self, write_queue_constructor: Option<WriteQueueConstructor>) -> bool {
114 if !QUEUES.read().contains_key(&self.uri) {
115 let mut queues = QUEUES.write();
116 if !queues.contains_key(&self.uri) {
117 let mut write_queue_constructor =
118 write_queue_constructor.unwrap_or_else(background_thread_queue);
119 queues.insert(self.uri.clone(), write_queue_constructor());
120 return true;
121 }
122 }
123 false
124 }
125
126 pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
127 ThreadSafeConnectionBuilder::<M> {
128 db_initialize_query: None,
129 write_queue_constructor: None,
130 connection: Self {
131 uri: Arc::from(uri),
132 persistent,
133 connection_initialize_query: None,
134 connections: Default::default(),
135 _migrator: PhantomData,
136 },
137 }
138 }
139
140 /// Opens a new db connection with the initialized file path. This is internal and only
141 /// called from the deref function.
142 fn open_file(uri: &str) -> Connection {
143 Connection::open_file(uri)
144 }
145
146 /// Opens a shared memory connection using the file path as the identifier. This is internal
147 /// and only called from the deref function.
148 fn open_shared_memory(uri: &str) -> Connection {
149 Connection::open_memory(Some(uri))
150 }
151
152 pub fn write<T: 'static + Send + Sync>(
153 &self,
154 callback: impl 'static + Send + FnOnce(&Connection) -> T,
155 ) -> impl Future<Output = T> {
156 // Check and invalidate queue and maybe recreate queue
157 let queues = QUEUES.read();
158 let write_channel = queues
159 .get(&self.uri)
160 .expect("Queues are inserted when build is called. This should always succeed");
161
162 // Create a one shot channel for the result of the queued write
163 // so we can await on the result
164 let (sender, receiver) = oneshot::channel();
165
166 let thread_safe_connection = (*self).clone();
167 write_channel(Box::new(move || {
168 let connection = thread_safe_connection.deref();
169 let result = connection.with_write(|connection| callback(connection));
170 sender.send(result).ok();
171 }));
172 receiver.map(|response| response.expect("Write queue unexpectedly closed"))
173 }
174
175 pub(crate) fn create_connection(
176 persistent: bool,
177 uri: &str,
178 connection_initialize_query: Option<&'static str>,
179 ) -> Connection {
180 let mut connection = if persistent {
181 Self::open_file(uri)
182 } else {
183 Self::open_shared_memory(uri)
184 };
185
186 // Disallow writes on the connection. The only writes allowed for thread safe connections
187 // are from the background thread that can serialize them.
188 *connection.write.get_mut() = false;
189
190 if let Some(initialize_query) = connection_initialize_query {
191 connection.exec(initialize_query).unwrap_or_else(|_| {
192 panic!("Initialize query failed to execute: {}", initialize_query)
193 })()
194 .unwrap()
195 }
196
197 connection
198 }
199}
200
201impl ThreadSafeConnection<()> {
202 /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
203 /// This allows construction to be infallible and not write to the db.
204 pub fn new(
205 uri: &str,
206 persistent: bool,
207 connection_initialize_query: Option<&'static str>,
208 write_queue_constructor: Option<WriteQueueConstructor>,
209 ) -> Self {
210 let connection = Self {
211 uri: Arc::from(uri),
212 persistent,
213 connection_initialize_query,
214 connections: Default::default(),
215 _migrator: PhantomData,
216 };
217
218 connection.initialize_queues(write_queue_constructor);
219 connection
220 }
221}
222
223impl<M: Migrator> Clone for ThreadSafeConnection<M> {
224 fn clone(&self) -> Self {
225 Self {
226 uri: self.uri.clone(),
227 persistent: self.persistent,
228 connection_initialize_query: self.connection_initialize_query,
229 connections: self.connections.clone(),
230 _migrator: PhantomData,
231 }
232 }
233}
234
235impl<M: Migrator> Deref for ThreadSafeConnection<M> {
236 type Target = Connection;
237
238 fn deref(&self) -> &Self::Target {
239 self.connections.get_or(|| {
240 Self::create_connection(self.persistent, &self.uri, self.connection_initialize_query)
241 })
242 }
243}
244
245pub fn background_thread_queue() -> WriteQueueConstructor {
246 use std::sync::mpsc::channel;
247
248 Box::new(|| {
249 let (sender, receiver) = channel::<QueuedWrite>();
250
251 thread::spawn(move || {
252 while let Ok(write) = receiver.recv() {
253 write()
254 }
255 });
256
257 let sender = UnboundedSyncSender::new(sender);
258 Box::new(move |queued_write| {
259 sender
260 .send(queued_write)
261 .expect("Could not send write action to background thread");
262 })
263 })
264}
265
266pub fn locking_queue() -> WriteQueueConstructor {
267 Box::new(|| {
268 let write_mutex = Mutex::new(());
269 Box::new(move |queued_write| {
270 let _lock = write_mutex.lock();
271 queued_write();
272 })
273 })
274}
275
#[cfg(test)]
mod test {
    use indoc::indoc;
    use std::{ops::Deref, panic, thread};

    use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};

    /// Builds many connections to the same shared in-memory database concurrently from
    /// separate threads. Each build runs the db initialization query and the migrations,
    /// and every build must succeed.
    #[test]
    fn many_initialize_and_migrate_queries_at_once() {
        let mut handles = vec![];

        enum TestDomain {}
        impl Domain for TestDomain {
            fn name() -> &'static str {
                "test"
            }
            fn migrations() -> &'static [&'static str] {
                &["CREATE TABLE test(col1 TEXT, col2 TEXT) STRICT;"]
            }
        }

        for _ in 0..100 {
            handles.push(thread::spawn(|| {
                let builder =
                    ThreadSafeConnection::<TestDomain>::builder("annoying-test.db", false)
                        .with_db_initialization_query("PRAGMA journal_mode=WAL")
                        .with_connection_initialize_query(indoc! {"
                            PRAGMA synchronous=NORMAL;
                            PRAGMA busy_timeout=1;
                            PRAGMA foreign_keys=TRUE;
                            PRAGMA case_sensitive_like=TRUE;
                        "});

                let _ = smol::block_on(builder.build()).unwrap().deref();
            }));
        }

        for handle in handles {
            // Re-raise worker panics on the test thread. Discarding the join result would
            // let a failed build pass the test silently.
            if let Err(panic_payload) = handle.join() {
                panic::resume_unwind(panic_payload);
            }
        }
    }

    /// In this migration, `workspaces` declares a foreign key on `active_pane`, which is not
    /// one of its columns. The migration is therefore expected to fail, making
    /// `build(..).unwrap()` panic.
    #[test]
    #[should_panic]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}
        impl Domain for TestWorkspace {
            fn name() -> &'static str {
                "workspace"
            }

            fn migrations() -> &'static [&'static str] {
                &["
                    CREATE TABLE workspaces(
                        workspace_id INTEGER PRIMARY KEY,
                        dock_visible INTEGER, -- Boolean
                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                    ) STRICT;

                    CREATE TABLE panes(
                        pane_id INTEGER PRIMARY KEY,
                        workspace_id INTEGER NOT NULL,
                        active INTEGER NOT NULL, -- Boolean
                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
                            ON DELETE CASCADE
                            ON UPDATE CASCADE
                    ) STRICT;
                "]
            }
        }

        let builder =
            ThreadSafeConnection::<TestWorkspace>::builder("wild_zed_lost_failure", false)
                .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true");

        smol::block_on(builder.build()).unwrap();
    }
}