1use anyhow::Context as _;
2use collections::HashMap;
3use futures::{Future, FutureExt, channel::oneshot};
4use parking_lot::{Mutex, RwLock};
5use std::{
6 marker::PhantomData,
7 ops::Deref,
8 sync::{Arc, LazyLock},
9 thread,
10};
11use thread_local::ThreadLocal;
12
13use crate::{connection::Connection, domain::Migrator, util::UnboundedSyncSender};
14
/// Upper bound on how many times `ThreadSafeConnectionBuilder::build` attempts to run the
/// migrations before returning the last error.
const MIGRATION_RETRIES: usize = 10;

/// A single unit of queued write work: a boxed closure that is invoked exactly once.
type QueuedWrite = Box<dyn 'static + Send + FnOnce()>;
/// Handle used to submit a `QueuedWrite`. One of these is stored per database uri in `QUEUES`.
type WriteQueue = Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>;
/// Factory that produces the `WriteQueue` for a database uri the first time that uri is seen.
type WriteQueueConstructor = Box<dyn 'static + Send + FnMut() -> WriteQueue>;

/// List of queues of tasks by database uri. This lets us serialize writes to the database
/// and have a single worker thread per db file. This means many thread safe connections
/// (possibly with different migrations) could all be communicating with the same background
/// thread.
static QUEUES: LazyLock<RwLock<HashMap<Arc<str>, WriteQueue>>> = LazyLock::new(Default::default);
26
/// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static,
/// whatever. It derefs to a synchronous connection by thread that is read only. A write capable connection
/// may be accessed by passing a callback to the `write` function, which will queue the callback
/// on the write queue registered for this connection's uri.
pub struct ThreadSafeConnection<M: Migrator + 'static = ()> {
    /// Database identifier; also the key used to look up the write queue in `QUEUES`.
    uri: Arc<str>,
    /// `true` opens the uri as a file, `false` opens a shared in-memory database named by the uri.
    persistent: bool,
    /// Optional query executed on every newly opened connection.
    connection_initialize_query: Option<&'static str>,
    /// Per-thread connections, created lazily on first deref; shared between clones.
    connections: Arc<ThreadLocal<Connection>>,
    /// Ties the connection to its migrator type without storing a value of it.
    _migrator: PhantomData<*mut M>,
}
37
// The `PhantomData<*mut M>` field suppresses the automatic `Send`/`Sync` impls, so they are
// asserted manually here.
// SAFETY: NOTE(review): presumably sound because no value of `M` is ever stored and each thread
// only touches its own `Connection` through the `ThreadLocal` — confirm against `Connection`.
unsafe impl<M: Migrator> Send for ThreadSafeConnection<M> {}
unsafe impl<M: Migrator> Sync for ThreadSafeConnection<M> {}
40
/// Builder for a [`ThreadSafeConnection`] that needs database initialization and/or migrations.
/// Configure it with the `with_*` methods and finish with `build`.
pub struct ThreadSafeConnectionBuilder<M: Migrator + 'static = ()> {
    /// Optional query run once through the write queue during `build`, before the migrations.
    db_initialize_query: Option<&'static str>,
    /// Optional override for how the write queue of this uri is created.
    write_queue_constructor: Option<WriteQueueConstructor>,
    /// The connection being configured; handed back by `build` on success.
    connection: ThreadSafeConnection<M>,
}
46
47impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
48 /// Sets the query to run every time a connection is opened. This must
49 /// be infallible (EG only use pragma statements) and not cause writes.
50 /// to the db or it will panic.
51 pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
52 self.connection.connection_initialize_query = Some(initialize_query);
53 self
54 }
55
56 /// Queues an initialization query for the database file. This must be infallible
57 /// but may cause changes to the database file such as with `PRAGMA journal_mode`
58 pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
59 self.db_initialize_query = Some(initialize_query);
60 self
61 }
62
63 /// Specifies how the thread safe connection should serialize writes. If provided
64 /// the connection will call the write_queue_constructor for each database file in
65 /// this process. The constructor is responsible for setting up a background thread or
66 /// async task which handles queued writes with the provided connection.
67 pub fn with_write_queue_constructor(
68 mut self,
69 write_queue_constructor: WriteQueueConstructor,
70 ) -> Self {
71 self.write_queue_constructor = Some(write_queue_constructor);
72 self
73 }
74
75 pub async fn build(self) -> anyhow::Result<ThreadSafeConnection<M>> {
76 self.connection
77 .initialize_queues(self.write_queue_constructor);
78
79 let db_initialize_query = self.db_initialize_query;
80
81 self.connection
82 .write(move |connection| {
83 if let Some(db_initialize_query) = db_initialize_query {
84 connection.exec(db_initialize_query).with_context(|| {
85 format!(
86 "Db initialize query failed to execute: {}",
87 db_initialize_query
88 )
89 })?()?;
90 }
91
92 // Retry failed migrations in case they were run in parallel from different
93 // processes. This gives a best attempt at migrating before bailing
94 let mut migration_result =
95 anyhow::Result::<()>::Err(anyhow::anyhow!("Migration never run"));
96
97 for _ in 0..MIGRATION_RETRIES {
98 migration_result = connection
99 .with_savepoint("thread_safe_multi_migration", || M::migrate(connection));
100
101 if migration_result.is_ok() {
102 break;
103 }
104 }
105
106 migration_result
107 })
108 .await?;
109
110 Ok(self.connection)
111 }
112}
113
114impl<M: Migrator> ThreadSafeConnection<M> {
115 fn initialize_queues(&self, write_queue_constructor: Option<WriteQueueConstructor>) -> bool {
116 if !QUEUES.read().contains_key(&self.uri) {
117 let mut queues = QUEUES.write();
118 if !queues.contains_key(&self.uri) {
119 let mut write_queue_constructor =
120 write_queue_constructor.unwrap_or_else(background_thread_queue);
121 queues.insert(self.uri.clone(), write_queue_constructor());
122 return true;
123 }
124 }
125 false
126 }
127
128 pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
129 ThreadSafeConnectionBuilder::<M> {
130 db_initialize_query: None,
131 write_queue_constructor: None,
132 connection: Self {
133 uri: Arc::from(uri),
134 persistent,
135 connection_initialize_query: None,
136 connections: Default::default(),
137 _migrator: PhantomData,
138 },
139 }
140 }
141
142 /// Opens a new db connection with the initialized file path. This is internal and only
143 /// called from the deref function.
144 fn open_file(uri: &str) -> Connection {
145 Connection::open_file(uri)
146 }
147
148 /// Opens a shared memory connection using the file path as the identifier. This is internal
149 /// and only called from the deref function.
150 fn open_shared_memory(uri: &str) -> Connection {
151 Connection::open_memory(Some(uri))
152 }
153
154 pub fn write<T: 'static + Send + Sync>(
155 &self,
156 callback: impl 'static + Send + FnOnce(&Connection) -> T,
157 ) -> impl Future<Output = T> {
158 // Check and invalidate queue and maybe recreate queue
159 let queues = QUEUES.read();
160 let write_channel = queues
161 .get(&self.uri)
162 .expect("Queues are inserted when build is called. This should always succeed");
163
164 // Create a one shot channel for the result of the queued write
165 // so we can await on the result
166 let (sender, receiver) = oneshot::channel();
167
168 let thread_safe_connection = (*self).clone();
169 write_channel(Box::new(move || {
170 let connection = thread_safe_connection.deref();
171 let result = connection.with_write(|connection| callback(connection));
172 sender.send(result).ok();
173 }));
174 receiver.map(|response| response.expect("Write queue unexpectedly closed"))
175 }
176
177 pub(crate) fn create_connection(
178 persistent: bool,
179 uri: &str,
180 connection_initialize_query: Option<&'static str>,
181 ) -> Connection {
182 let mut connection = if persistent {
183 Self::open_file(uri)
184 } else {
185 Self::open_shared_memory(uri)
186 };
187
188 // Disallow writes on the connection. The only writes allowed for thread safe connections
189 // are from the background thread that can serialize them.
190 *connection.write.get_mut() = false;
191
192 if let Some(initialize_query) = connection_initialize_query {
193 connection.exec(initialize_query).unwrap_or_else(|_| {
194 panic!("Initialize query failed to execute: {}", initialize_query)
195 })()
196 .unwrap()
197 }
198
199 connection
200 }
201}
202
203impl ThreadSafeConnection<()> {
204 /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
205 /// This allows construction to be infallible and not write to the db.
206 pub fn new(
207 uri: &str,
208 persistent: bool,
209 connection_initialize_query: Option<&'static str>,
210 write_queue_constructor: Option<WriteQueueConstructor>,
211 ) -> Self {
212 let connection = Self {
213 uri: Arc::from(uri),
214 persistent,
215 connection_initialize_query,
216 connections: Default::default(),
217 _migrator: PhantomData,
218 };
219
220 connection.initialize_queues(write_queue_constructor);
221 connection
222 }
223}
224
225impl<M: Migrator> Clone for ThreadSafeConnection<M> {
226 fn clone(&self) -> Self {
227 Self {
228 uri: self.uri.clone(),
229 persistent: self.persistent,
230 connection_initialize_query: self.connection_initialize_query,
231 connections: self.connections.clone(),
232 _migrator: PhantomData,
233 }
234 }
235}
236
237impl<M: Migrator> Deref for ThreadSafeConnection<M> {
238 type Target = Connection;
239
240 fn deref(&self) -> &Self::Target {
241 self.connections.get_or(|| {
242 Self::create_connection(self.persistent, &self.uri, self.connection_initialize_query)
243 })
244 }
245}
246
247pub fn background_thread_queue() -> WriteQueueConstructor {
248 use std::sync::mpsc::channel;
249
250 Box::new(|| {
251 let (sender, receiver) = channel::<QueuedWrite>();
252
253 thread::spawn(move || {
254 while let Ok(write) = receiver.recv() {
255 write()
256 }
257 });
258
259 let sender = UnboundedSyncSender::new(sender);
260 Box::new(move |queued_write| {
261 sender
262 .send(queued_write)
263 .expect("Could not send write action to background thread");
264 })
265 })
266}
267
268pub fn locking_queue() -> WriteQueueConstructor {
269 Box::new(|| {
270 let write_mutex = Mutex::new(());
271 Box::new(move |queued_write| {
272 let _lock = write_mutex.lock();
273 queued_write();
274 })
275 })
276}
277
#[cfg(test)]
mod test {
    use indoc::indoc;
    use std::ops::Deref;

    use std::thread;

    use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};

    /// Builds the same in-memory database from 100 threads at once, each running the db
    /// initialization query, the connection initialization query and the migration.
    #[test]
    fn many_initialize_and_migrate_queries_at_once() {
        let mut handles = vec![];

        enum TestDomain {}
        impl Domain for TestDomain {
            fn name() -> &'static str {
                "test"
            }
            fn migrations() -> &'static [&'static str] {
                &["CREATE TABLE test(col1 TEXT, col2 TEXT) STRICT;"]
            }
        }

        for _ in 0..100 {
            handles.push(thread::spawn(|| {
                let builder =
                    ThreadSafeConnection::<TestDomain>::builder("annoying-test.db", false)
                        .with_db_initialization_query("PRAGMA journal_mode=WAL")
                        .with_connection_initialize_query(indoc! {"
                                PRAGMA synchronous=NORMAL;
                                PRAGMA busy_timeout=1;
                                PRAGMA foreign_keys=TRUE;
                                PRAGMA case_sensitive_like=TRUE;
                            "});

                // The deref forces this thread's own connection to be opened as well.
                let _ = smol::block_on(builder.build()).unwrap().deref();
            }));
        }

        // NOTE(review): the join result is discarded, so a panic inside a worker thread
        // does not fail this test — confirm that this is intended.
        for handle in handles {
            let _ = handle.join();
        }
    }

    /// Expects `build` to fail for a broken migration. The `workspaces` table declares a
    /// foreign key on `active_pane`, which is not one of its columns; presumably that is
    /// what makes the migration fail and the `unwrap` below panic — confirm.
    #[test]
    #[should_panic]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}
        impl Domain for TestWorkspace {
            fn name() -> &'static str {
                "workspace"
            }

            fn migrations() -> &'static [&'static str] {
                &["
                    CREATE TABLE workspaces(
                        workspace_id INTEGER PRIMARY KEY,
                        dock_visible INTEGER, -- Boolean
                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                    ) STRICT;

                    CREATE TABLE panes(
                        pane_id INTEGER PRIMARY KEY,
                        workspace_id INTEGER NOT NULL,
                        active INTEGER NOT NULL, -- Boolean
                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
                            ON DELETE CASCADE
                            ON UPDATE CASCADE
                    ) STRICT;
                "]
            }
        }

        let builder =
            ThreadSafeConnection::<TestWorkspace>::builder("wild_zed_lost_failure", false)
                .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true");

        smol::block_on(builder.build()).unwrap();
    }
}