1use futures::{Future, FutureExt};
2use lazy_static::lazy_static;
3use parking_lot::RwLock;
4use std::{collections::HashMap, marker::PhantomData, ops::Deref, sync::Arc, thread};
5use thread_local::ThreadLocal;
6
7use crate::{
8 connection::Connection,
9 domain::{Domain, Migrator},
10 util::UnboundedSyncSender,
11};
12
13type QueuedWrite = Box<dyn 'static + Send + FnOnce(&Connection)>;
14
15lazy_static! {
16 static ref QUEUES: RwLock<HashMap<Arc<str>, UnboundedSyncSender<QueuedWrite>>> =
17 Default::default();
18}
19
20pub struct ThreadSafeConnection<M: Migrator> {
21 uri: Arc<str>,
22 persistent: bool,
23 initialize_query: Option<&'static str>,
24 connections: Arc<ThreadLocal<Connection>>,
25 _migrator: PhantomData<M>,
26}
27
28unsafe impl<T: Migrator> Send for ThreadSafeConnection<T> {}
29unsafe impl<T: Migrator> Sync for ThreadSafeConnection<T> {}
30
31impl<M: Migrator> ThreadSafeConnection<M> {
32 pub fn new(uri: &str, persistent: bool) -> Self {
33 Self {
34 uri: Arc::from(uri),
35 persistent,
36 initialize_query: None,
37 connections: Default::default(),
38 _migrator: PhantomData,
39 }
40 }
41
42 /// Sets the query to run every time a connection is opened. This must
43 /// be infallible (EG only use pragma statements)
44 pub fn with_initialize_query(mut self, initialize_query: &'static str) -> Self {
45 self.initialize_query = Some(initialize_query);
46 self
47 }
48
49 /// Opens a new db connection with the initialized file path. This is internal and only
50 /// called from the deref function.
51 /// If opening fails, the connection falls back to a shared memory connection
52 fn open_file(&self) -> Connection {
53 // This unwrap is secured by a panic in the constructor. Be careful if you remove it!
54 Connection::open_file(self.uri.as_ref())
55 }
56
57 /// Opens a shared memory connection using the file path as the identifier. This unwraps
58 /// as we expect it always to succeed
59 fn open_shared_memory(&self) -> Connection {
60 Connection::open_memory(Some(self.uri.as_ref()))
61 }
62
63 // Open a new connection for the given domain, leaving this
64 // connection intact.
65 pub fn for_domain<D2: Domain>(&self) -> ThreadSafeConnection<D2> {
66 ThreadSafeConnection {
67 uri: self.uri.clone(),
68 persistent: self.persistent,
69 initialize_query: self.initialize_query,
70 connections: Default::default(),
71 _migrator: PhantomData,
72 }
73 }
74
75 pub fn write<T: 'static + Send + Sync>(
76 &self,
77 callback: impl 'static + Send + FnOnce(&Connection) -> T,
78 ) -> impl Future<Output = T> {
79 // Startup write thread for this database if one hasn't already
80 // been started and insert a channel to queue work for it
81 if !QUEUES.read().contains_key(&self.uri) {
82 use std::sync::mpsc::channel;
83
84 let (sender, reciever) = channel::<QueuedWrite>();
85 let mut write_connection = self.create_connection();
86 // Enable writes for this connection
87 write_connection.write = true;
88 thread::spawn(move || {
89 while let Ok(write) = reciever.recv() {
90 write(&write_connection)
91 }
92 });
93
94 let mut queues = QUEUES.write();
95 queues.insert(self.uri.clone(), UnboundedSyncSender::new(sender));
96 }
97
98 // Grab the queue for this database
99 let queues = QUEUES.read();
100 let write_channel = queues.get(&self.uri).unwrap();
101
102 // Create a one shot channel for the result of the queued write
103 // so we can await on the result
104 let (sender, reciever) = futures::channel::oneshot::channel();
105 write_channel
106 .send(Box::new(move |connection| {
107 sender.send(callback(connection)).ok();
108 }))
109 .expect("Could not send write action to background thread");
110
111 reciever.map(|response| response.expect("Background thread unexpectedly closed"))
112 }
113
114 pub(crate) fn create_connection(&self) -> Connection {
115 let mut connection = if self.persistent {
116 self.open_file()
117 } else {
118 self.open_shared_memory()
119 };
120
121 // Enable writes for the migrations and initialization queries
122 connection.write = true;
123
124 if let Some(initialize_query) = self.initialize_query {
125 connection.exec(initialize_query).expect(&format!(
126 "Initialize query failed to execute: {}",
127 initialize_query
128 ))()
129 .unwrap();
130 }
131
132 M::migrate(&connection).expect("Migrations failed");
133
134 // Disable db writes for normal thread local connection
135 connection.write = false;
136 connection
137 }
138}
139
140impl<D: Domain> Clone for ThreadSafeConnection<D> {
141 fn clone(&self) -> Self {
142 Self {
143 uri: self.uri.clone(),
144 persistent: self.persistent,
145 initialize_query: self.initialize_query.clone(),
146 connections: self.connections.clone(),
147 _migrator: PhantomData,
148 }
149 }
150}
151
152// TODO:
153// 1. When migration or initialization fails, move the corrupted db to a holding place and create a new one
154// 2. If the new db also fails, downgrade to a shared in memory db
155// 3. In either case notify the user about what went wrong
156impl<M: Migrator> Deref for ThreadSafeConnection<M> {
157 type Target = Connection;
158
159 fn deref(&self) -> &Self::Target {
160 self.connections.get_or(|| self.create_connection())
161 }
162}
163
#[cfg(test)]
mod test {
    use std::ops::Deref;

    use super::ThreadSafeConnection;
    use crate::domain::Domain;

    /// The `workspaces` table declares a foreign key on `active_pane`, a column
    /// it does not have, so migrating must fail and opening the connection must
    /// panic.
    #[test]
    #[should_panic]
    fn wild_zed_lost_failure() {
        enum TestWorkspace {}

        impl Domain for TestWorkspace {
            fn name() -> &'static str {
                "workspace"
            }

            fn migrations() -> &'static [&'static str] {
                &["
                    CREATE TABLE workspaces(
                        workspace_id INTEGER PRIMARY KEY,
                        dock_visible INTEGER, -- Boolean
                        dock_anchor TEXT, -- Enum: 'Bottom' / 'Right' / 'Expanded'
                        dock_pane INTEGER, -- NULL indicates that we don't have a dock pane yet
                        timestamp TEXT DEFAULT CURRENT_TIMESTAMP NOT NULL,
                        FOREIGN KEY(dock_pane) REFERENCES panes(pane_id),
                        FOREIGN KEY(active_pane) REFERENCES panes(pane_id)
                    ) STRICT;

                    CREATE TABLE panes(
                        pane_id INTEGER PRIMARY KEY,
                        workspace_id INTEGER NOT NULL,
                        active INTEGER NOT NULL, -- Boolean
                        FOREIGN KEY(workspace_id) REFERENCES workspaces(workspace_id)
                            ON DELETE CASCADE
                            ON UPDATE CASCADE
                    ) STRICT;
                "]
            }
        }

        let connection = ThreadSafeConnection::<TestWorkspace>::new("wild_zed_lost_failure", false)
            .with_initialize_query("PRAGMA FOREIGN_KEYS=true");
        // The first dereference on this thread lazily opens the connection and
        // runs the migrations, which is where the expected panic happens.
        let _ = connection.deref();
    }
}