1use futures::{channel::oneshot, Future, FutureExt};
2use lazy_static::lazy_static;
3use parking_lot::RwLock;
4use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
5use thread_local::ThreadLocal;
6
7use crate::{
8 connection::Connection,
9 domain::{Domain, Migrator},
10 util::UnboundedSyncSender,
11};
12
/// Maximum number of times the migrations are attempted during `build` before the
/// last failure is reported by panicking.
const MIGRATION_RETRIES: usize = 10;
14
15type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
16type WriteQueueConstructor =
17 Box<dyn 'static + Send + FnMut(Connection) -> Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>;
18lazy_static! {
19 /// List of queues of tasks by database uri. This lets us serialize writes to the database
20 /// and have a single worker thread per db file. This means many thread safe connections
21 /// (possibly with different migrations) could all be communicating with the same background
22 /// thread.
23 static ref QUEUES: RwLock<HashMap<Arc<str>, Box<dyn 'static + Send + Sync + Fn(QueuedWrite)>>> =
24 Default::default();
25}
26
27/// Thread safe connection to a given database file or in memory db. This can be cloned, shared, static,
28/// whatever. It derefs to a synchronous connection by thread that is read only. A write capable connection
29/// may be accessed by passing a callback to the `write` function which will queue the callback
30pub struct ThreadSafeConnection<M: Migrator = ()> {
31 uri: Arc<str>,
32 persistent: bool,
33 connection_initialize_query: Option<&'static str>,
34 connections: Arc<ThreadLocal<Connection>>,
35 _migrator: PhantomData<M>,
36}
37
38unsafe impl<T: Migrator> Send for ThreadSafeConnection<T> {}
39unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
40
41pub struct ThreadSafeConnectionBuilder<M: Migrator = ()> {
42 db_initialize_query: Option<&'static str>,
43 write_queue_constructor: Option<WriteQueueConstructor>,
44 connection: ThreadSafeConnection<M>,
45}
46
47impl<M: Migrator> ThreadSafeConnectionBuilder<M> {
48 /// Sets the query to run every time a connection is opened. This must
49 /// be infallible (EG only use pragma statements) and not cause writes.
50 /// to the db or it will panic.
51 pub fn with_connection_initialize_query(mut self, initialize_query: &'static str) -> Self {
52 self.connection.connection_initialize_query = Some(initialize_query);
53 self
54 }
55
56 /// Specifies how the thread safe connection should serialize writes. If provided
57 /// the connection will call the write_queue_constructor for each database file in
58 /// this process. The constructor is responsible for setting up a background thread or
59 /// async task which handles queued writes with the provided connection.
60 pub fn with_write_queue_constructor(
61 mut self,
62 write_queue_constructor: WriteQueueConstructor,
63 ) -> Self {
64 self.write_queue_constructor = Some(write_queue_constructor);
65 self
66 }
67
68 /// Queues an initialization query for the database file. This must be infallible
69 /// but may cause changes to the database file such as with `PRAGMA journal_mode`
70 pub fn with_db_initialization_query(mut self, initialize_query: &'static str) -> Self {
71 self.db_initialize_query = Some(initialize_query);
72 self
73 }
74
75 pub async fn build(self) -> ThreadSafeConnection<M> {
76 if !QUEUES.read().contains_key(&self.connection.uri) {
77 let mut queues = QUEUES.write();
78 if !queues.contains_key(&self.connection.uri) {
79 let mut write_connection = self.connection.create_connection();
80 // Enable writes for this connection
81 write_connection.write = true;
82 if let Some(mut write_queue_constructor) = self.write_queue_constructor {
83 let write_channel = write_queue_constructor(write_connection);
84 queues.insert(self.connection.uri.clone(), write_channel);
85 } else {
86 use std::sync::mpsc::channel;
87
88 let (sender, reciever) = channel::<QueuedWrite>();
89 thread::spawn(move || {
90 while let Ok(write) = reciever.recv() {
91 write(&write_connection)
92 }
93 });
94
95 let sender = UnboundedSyncSender::new(sender);
96 queues.insert(
97 self.connection.uri.clone(),
98 Box::new(move |queued_write| {
99 sender
100 .send(queued_write)
101 .expect("Could not send write action to backgorund thread");
102 }),
103 );
104 }
105 }
106 }
107
108 let db_initialize_query = self.db_initialize_query;
109
110 self.connection
111 .write(move |connection| {
112 if let Some(db_initialize_query) = db_initialize_query {
113 connection.exec(db_initialize_query).expect(&format!(
114 "Db initialize query failed to execute: {}",
115 db_initialize_query
116 ))()
117 .unwrap();
118 }
119
120 let mut failure_result = None;
121 for _ in 0..MIGRATION_RETRIES {
122 failure_result = Some(M::migrate(connection));
123 if failure_result.as_ref().unwrap().is_ok() {
124 break;
125 }
126 }
127
128 failure_result.unwrap().expect("Migration failed");
129 })
130 .await;
131
132 self.connection
133 }
134}
135
136impl<M: Migrator> ThreadSafeConnection<M> {
137 pub fn builder(uri: &str, persistent: bool) -> ThreadSafeConnectionBuilder<M> {
138 ThreadSafeConnectionBuilder::<M> {
139 db_initialize_query: None,
140 write_queue_constructor: None,
141 connection: Self {
142 uri: Arc::from(uri),
143 persistent,
144 connection_initialize_query: None,
145 connections: Default::default(),
146 _migrator: PhantomData,
147 },
148 }
149 }
150
151 /// Opens a new db connection with the initialized file path. This is internal and only
152 /// called from the deref function.
153 fn open_file(&self) -> Connection {
154 Connection::open_file(self.uri.as_ref())
155 }
156
157 /// Opens a shared memory connection using the file path as the identifier. This is internal
158 /// and only called from the deref function.
159 fn open_shared_memory(&self) -> Connection {
160 Connection::open_memory(Some(self.uri.as_ref()))
161 }
162
163 pub fn write<T: 'static + Send + Sync>(
164 &self,
165 callback: impl 'static + Send + FnOnce(&Connection) -> T,
166 ) -> impl Future<Output = T> {
167 let queues = QUEUES.read();
168 let write_channel = queues
169 .get(&self.uri)
170 .expect("Queues are inserted when build is called. This should always succeed");
171
172 // Create a one shot channel for the result of the queued write
173 // so we can await on the result
174 let (sender, reciever) = oneshot::channel();
175 write_channel(Box::new(move |connection| {
176 sender.send(callback(connection)).ok();
177 }));
178 reciever.map(|response| response.expect("Background writer thread unexpectedly closed"))
179 }
180
181 pub(crate) fn create_connection(&self) -> Connection {
182 let mut connection = if self.persistent {
183 self.open_file()
184 } else {
185 self.open_shared_memory()
186 };
187
188 // Disallow writes on the connection. The only writes allowed for thread safe connections
189 // are from the background thread that can serialize them.
190 connection.write = false;
191
192 if let Some(initialize_query) = self.connection_initialize_query {
193 connection.exec(initialize_query).expect(&format!(
194 "Initialize query failed to execute: {}",
195 initialize_query
196 ))()
197 .unwrap()
198 }
199
200 connection
201 }
202}
203
204impl ThreadSafeConnection<()> {
205 /// Special constructor for ThreadSafeConnection which disallows db initialization and migrations.
206 /// This allows construction to be infallible and not write to the db.
207 pub fn new(
208 uri: &str,
209 persistent: bool,
210 connection_initialize_query: Option<&'static str>,
211 ) -> Self {
212 Self {
213 uri: Arc::from(uri),
214 persistent,
215 connection_initialize_query,
216 connections: Default::default(),
217 _migrator: PhantomData,
218 }
219 }
220}
221
222impl<D: Domain> Clone for ThreadSafeConnection<D> {
223 fn clone(&self) -> Self {
224 Self {
225 uri: self.uri.clone(),
226 persistent: self.persistent,
227 connection_initialize_query: self.connection_initialize_query.clone(),
228 connections: self.connections.clone(),
229 _migrator: PhantomData,
230 }
231 }
232}
233
234// TODO:
235// 1. When migration or initialization fails, move the corrupted db to a holding place and create a new one
236// 2. If the new db also fails, downgrade to a shared in memory db
237// 3. In either case notify the user about what went wrong
238impl<M: Migrator> Deref for ThreadSafeConnection<M> {
239 type Target = Connection;
240
241 fn deref(&self) -> &Self::Target {
242 self.connections.get_or(|| self.create_connection())
243 }
244}
245
#[cfg(test)]
mod test {
    use indoc::indoc;
    use std::{ops::Deref, thread};

    use crate::{domain::Domain, thread_safe_connection::ThreadSafeConnection};

    /// Builds the same shared-memory db from 100 threads at once. The db initialization
    /// query and migrations are queued concurrently, and each thread opens its own read
    /// connection.
    #[test]
    fn many_initialize_and_migrate_queries_at_once() {
        let mut handles = vec![];

        enum TestDomain {}
        impl Domain for TestDomain {
            fn name() -> &'static str {
                "test"
            }
            fn migrations() -> &'static [&'static str] {
                &["CREATE TABLE test(col1 TEXT, col2 TEXT) STRICT;"]
            }
        }

        for _ in 0..100 {
            handles.push(thread::spawn(|| {
                let builder =
                    ThreadSafeConnection::<TestDomain>::builder("annoying-test.db", false)
                        .with_db_initialization_query("PRAGMA journal_mode=WAL")
                        .with_connection_initialize_query(indoc! {"
                            PRAGMA synchronous=NORMAL;
                            PRAGMA busy_timeout=1;
                            PRAGMA foreign_keys=TRUE;
                            PRAGMA case_sensitive_like=TRUE;
                        "});
                // Dereferencing opens this thread's read connection.
                let _ = smol::block_on(builder.build()).deref();
            }));
        }

        // NOTE(review): join results are discarded, so a panic inside a spawned thread
        // does not fail this test. Presumably it only guards against deadlocks; confirm.
        for handle in handles {
            let _ = handle.join();
        }
    }

    /// The `workspaces` migration declares `FOREIGN KEY(active_pane)` without an
    /// `active_pane` column, so migrating fails and building must panic.
    #[test]
    #[should_panic]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}
        impl Domain for TestWorkspace {
            fn name() -> &'static str {
                "workspace"
            }

            fn migrations() -> &'static [&'static str] {
                &["
                    CREATE TABLE workspaces(
                        workspace_id INTEGER PRIMARY KEY,
                        dock_visible INTEGER, -- Boolean
                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                    ) STRICT;

                    CREATE TABLE panes(
                        pane_id INTEGER PRIMARY KEY,
                        workspace_id INTEGER NOT NULL,
                        active INTEGER NOT NULL, -- Boolean
                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
                            ON DELETE CASCADE
                            ON UPDATE CASCADE
                    ) STRICT;
                "]
            }
        }

        let builder =
            ThreadSafeConnection::<TestWorkspace>::builder("wild_zed_lost_failure", false)
                .with_connection_initialize_query("PRAGMA FOREIGN_KEYS=true");

        smol::block_on(builder.build());
    }
}