1use anyhow::Context;
2use futures::{channel::oneshot, Future, FutureExt};
3use lazy_static::lazy_static;
4use parking_lot::{Mutex, RwLock};
5use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
6use thread_local::ThreadLocal;
7
8use crate::{connection::Connection, domain::Migrator, util::UnboundedSyncSender};
9
/// Upper bound on how many times migrations are attempted while building a
/// connection before the last migration error is returned.
const MIGRATION_RETRIES: usize = 10;

/// A single queued write: a boxed closure that is run exactly once.
type QueuedWrite = Box<dyn FnOnce() + Send + 'static>;
/// A write queue: a shareable function that accepts a [`QueuedWrite`] for execution.
type WriteQueue = Box<dyn Fn(QueuedWrite) + Send + Sync + 'static>;
/// A factory that produces a [`WriteQueue`] each time it is called.
type WriteQueueConstructor = Box<dyn FnMut() -> WriteQueue + Send + 'static>;
15lazy_static! {
16 /// List of queues of tasks by database uri. This lets us serialize writes to the database
17 /// and have a single worker thread per db file. This means many thread safe connections
18 /// (possibly with different migrations) could all be communicating with the same background
19 /// thread.
20 static ref QUEUES: RwLock<HashMap<Arc<str>, WriteQueue>> =
21 Default::default();
22}
23
24/// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static,
25/// whatever. It derefs to a synchronous connection by thread that is read only. A write capable connection
26/// may be accessed by passing a callback to the `write` function which will queue the callback
27pub struct ThreadSafeConnection<M: Migrator + 'static = ()> {
28 uri: Arc<str>,
29 persistent: bool,
30 connection_initialize_query: Option<&'static str>,
31 connections: Arc<ThreadLocal<Connection>>,
32 _migrator: PhantomData<*mut M>,
33}
34
35unsafe impl<M: Migrator> Send for ThreadSafeConnection<M> {}
36unsafe impl<M: Migrator> Sync for ThreadSafeConnection<M> {}
37
38pub struct ThreadSafeConnectionBuilder<M: Migrator + 'static = ()> {
39 db_initialize_query: Option<&'static str>,
40 write_queue_constructor: Option<WriteQueueConstructor>,
41 connection: ThreadSafeConnection<M>,
42}
43
44impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
45 /// Sets the query to run every time a connection is opened. This must
46 /// be infallible (EG only use pragma statements) and not cause writes.
47 /// to the db or it will panic.
48 pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
49 self.connection.connection_initialize_query = Some(initialize_query);
50 self
51 }
52
53 /// Queues an initialization query for the database file. This must be infallible
54 /// but may cause changes to the database file such as with `PRAGMA journal_mode`
55 pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
56 self.db_initialize_query = Some(initialize_query);
57 self
58 }
59
60 /// Specifies how the thread safe connection should serialize writes. If provided
61 /// the connection will call the write_queue_constructor for each database file in
62 /// this process. The constructor is responsible for setting up a background thread or
63 /// async task which handles queued writes with the provided connection.
64 pub fn with_write_queue_constructor(
65 mut self,
66 write_queue_constructor: WriteQueueConstructor,
67 ) -> Self {
68 self.write_queue_constructor = Some(write_queue_constructor);
69 self
70 }
71
72 pub async fn build(self) -> anyhow::Result<ThreadSafeConnection<M>> {
73 self.connection
74 .initialize_queues(self.write_queue_constructor);
75
76 let db_initialize_query = self.db_initialize_query;
77
78 self.connection
79 .write(move |connection| {
80 if let Some(db_initialize_query) = db_initialize_query {
81 connection.exec(db_initialize_query).with_context(|| {
82 format!(
83 "Db initialize query failed to execute: {}",
84 db_initialize_query
85 )
86 })?()?;
87 }
88
89 // Retry failed migrations in case they were run in parallel from different
90 // processes. This gives a best attempt at migrating before bailing
91 let mut migration_result =
92 anyhow::Result::<()>::Err(anyhow::anyhow!("Migration never run"));
93
94 for _ in 0..MIGRATION_RETRIES {
95 migration_result = connection
96 .with_savepoint("thread_safe_multi_migration", || M::migrate(connection));
97
98 if migration_result.is_ok() {
99 break;
100 }
101 }
102
103 migration_result
104 })
105 .await?;
106
107 Ok(self.connection)
108 }
109}
110
111impl<M: Migrator> ThreadSafeConnection<M> {
112 fn initialize_queues(&self, write_queue_constructor: Option<WriteQueueConstructor>) -> bool {
113 if !QUEUES.read().contains_key(&self.uri) {
114 let mut queues = QUEUES.write();
115 if !queues.contains_key(&self.uri) {
116 let mut write_queue_constructor =
117 write_queue_constructor.unwrap_or_else(background_thread_queue);
118 queues.insert(self.uri.clone(), write_queue_constructor());
119 return true;
120 }
121 }
122 false
123 }
124
125 pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
126 ThreadSafeConnectionBuilder::<M> {
127 db_initialize_query: None,
128 write_queue_constructor: None,
129 connection: Self {
130 uri: Arc::from(uri),
131 persistent,
132 connection_initialize_query: None,
133 connections: Default::default(),
134 _migrator: PhantomData,
135 },
136 }
137 }
138
139 /// Opens a new db connection with the initialized file path. This is internal and only
140 /// called from the deref function.
141 fn open_file(uri: &str) -> Connection {
142 Connection::open_file(uri)
143 }
144
145 /// Opens a shared memory connection using the file path as the identifier. This is internal
146 /// and only called from the deref function.
147 fn open_shared_memory(uri: &str) -> Connection {
148 Connection::open_memory(Some(uri))
149 }
150
151 pub fn write<T: 'static + Send + Sync>(
152 &self,
153 callback: impl 'static + Send + FnOnce(&Connection) -> T,
154 ) -> impl Future<Output = T> {
155 // Check and invalidate queue and maybe recreate queue
156 let queues = QUEUES.read();
157 let write_channel = queues
158 .get(&self.uri)
159 .expect("Queues are inserted when build is called. This should always succeed");
160
161 // Create a one shot channel for the result of the queued write
162 // so we can await on the result
163 let (sender, receiver) = oneshot::channel();
164
165 let thread_safe_connection = (*self).clone();
166 write_channel(Box::new(move || {
167 let connection = thread_safe_connection.deref();
168 let result = connection.with_write(|connection| callback(connection));
169 sender.send(result).ok();
170 }));
171 receiver.map(|response| response.expect("Write queue unexpectedly closed"))
172 }
173
174 pub(crate) fn create_connection(
175 persistent: bool,
176 uri: &str,
177 connection_initialize_query: Option<&'static str>,
178 ) -> Connection {
179 let mut connection = if persistent {
180 Self::open_file(uri)
181 } else {
182 Self::open_shared_memory(uri)
183 };
184
185 // Disallow writes on the connection. The only writes allowed for thread safe connections
186 // are from the background thread that can serialize them.
187 *connection.write.get_mut() = false;
188
189 if let Some(initialize_query) = connection_initialize_query {
190 connection.exec(initialize_query).unwrap_or_else(|_| {
191 panic!("Initialize query failed to execute: {}", initialize_query)
192 })()
193 .unwrap()
194 }
195
196 connection
197 }
198}
199
200impl ThreadSafeConnection<()> {
201 /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
202 /// This allows construction to be infallible and not write to the db.
203 pub fn new(
204 uri: &str,
205 persistent: bool,
206 connection_initialize_query: Option<&'static str>,
207 write_queue_constructor: Option<WriteQueueConstructor>,
208 ) -> Self {
209 let connection = Self {
210 uri: Arc::from(uri),
211 persistent,
212 connection_initialize_query,
213 connections: Default::default(),
214 _migrator: PhantomData,
215 };
216
217 connection.initialize_queues(write_queue_constructor);
218 connection
219 }
220}
221
222impl<M: Migrator> Clone for ThreadSafeConnection<M> {
223 fn clone(&self) -> Self {
224 Self {
225 uri: self.uri.clone(),
226 persistent: self.persistent,
227 connection_initialize_query: self.connection_initialize_query,
228 connections: self.connections.clone(),
229 _migrator: PhantomData,
230 }
231 }
232}
233
234impl<M: Migrator> Deref for ThreadSafeConnection<M> {
235 type Target = Connection;
236
237 fn deref(&self) -> &Self::Target {
238 self.connections.get_or(|| {
239 Self::create_connection(self.persistent, &self.uri, self.connection_initialize_query)
240 })
241 }
242}
243
244pub fn background_thread_queue() -> WriteQueueConstructor {
245 use std::sync::mpsc::channel;
246
247 Box::new(|| {
248 let (sender, receiver) = channel::<QueuedWrite>();
249
250 thread::spawn(move || {
251 while let Ok(write) = receiver.recv() {
252 write()
253 }
254 });
255
256 let sender = UnboundedSyncSender::new(sender);
257 Box::new(move |queued_write| {
258 sender
259 .send(queued_write)
260 .expect("Could not send write action to background thread");
261 })
262 })
263}
264
265pub fn locking_queue() -> WriteQueueConstructor {
266 Box::new(|| {
267 let write_mutex = Mutex::new(());
268 Box::new(move |queued_write| {
269 let _lock = write_mutex.lock();
270 queued_write();
271 })
272 })
273}
274
#[cfg(test)]
mod test {
    use indoc::indoc;

    // Use the standard trait directly rather than `lazy_static`'s hidden `__Deref`
    // re-export, which is an implementation detail of that crate.
    use std::{ops::Deref, thread};

    use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};

    /// Builds the same shared in-memory database from 100 threads at once, each
    /// running the db initialization query, the connection initialization query and
    /// the migration, then dereferences the result to open a connection.
    #[test]
    fn many_initialize_and_migrate_queries_at_once() {
        let mut handles = vec![];

        enum TestDomain {}
        impl Domain for TestDomain {
            fn name() -> &'static str {
                "test"
            }
            fn migrations() -> &'static [&'static str] {
                &["CREATE TABLE test(col1 TEXT, col2 TEXT) STRICT;"]
            }
        }

        for _ in 0..100 {
            handles.push(thread::spawn(|| {
                let builder =
                    ThreadSafeConnection::<TestDomain>::builder("annoying-test.db", false)
                        .with_db_initialization_query("PRAGMA journal_mode=WAL")
                        .with_connection_initialize_query(indoc! {"
                                PRAGMA synchronous=NORMAL;
                                PRAGMA busy_timeout=1;
                                PRAGMA foreign_keys=TRUE;
                                PRAGMA case_sensitive_like=TRUE;
                            "});

                let _ = smol::block_on(builder.build()).unwrap().deref();
            }));
        }

        // NOTE(review): join results are discarded, so a panic inside a spawned
        // thread does not fail this test. Kept as-is; confirm whether that is
        // intentional before tightening it.
        for handle in handles {
            let _ = handle.join();
        }
    }

    /// Regression test: the `workspaces` table declares `FOREIGN KEY(active_pane)`
    /// although it has no `active_pane` column. Building the connection is expected
    /// to panic (presumably because the migration is rejected and the `unwrap`
    /// below hits the error).
    #[test]
    #[should_panic]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}
        impl Domain for TestWorkspace {
            fn name() -> &'static str {
                "workspace"
            }

            fn migrations() -> &'static [&'static str] {
                &["
                    CREATE TABLE workspaces(
                        workspace_id INTEGER PRIMARY KEY,
                        dock_visible INTEGER, -- Boolean
                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                    ) STRICT;

                    CREATE TABLE panes(
                        pane_id INTEGER PRIMARY KEY,
                        workspace_id INTEGER NOT NULL,
                        active INTEGER NOT NULL, -- Boolean
                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
                            ON DELETE CASCADE
                            ON UPDATE CASCADE
                    ) STRICT;
                "]
            }
        }

        let builder =
            ThreadSafeConnection::<TestWorkspace>::builder("wild_zed_lost_failure", false)
                .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true");

        smol::block_on(builder.build()).unwrap();
    }
}