1use std::{
2 cell::RefCell,
3 ffi::{CStr, CString},
4 marker::PhantomData,
5 path::Path,
6 ptr,
7};
8
9use anyhow::{anyhow, Result};
10use libsqlite3_sys::*;
11
12pub struct Connection {
13 pub(crate) sqlite3: *mut sqlite3,
14 persistent: bool,
15 pub(crate) write: RefCell<bool>,
16 _sqlite: PhantomData<sqlite3>,
17}
18unsafe impl Send for Connection {}
19
20impl Connection {
21 pub(crate) fn open(uri: &str, persistent: bool) -> Result<Self> {
22 let mut connection = Self {
23 sqlite3: 0 as *mut _,
24 persistent,
25 write: RefCell::new(true),
26 _sqlite: PhantomData,
27 };
28
29 let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
30 unsafe {
31 sqlite3_open_v2(
32 CString::new(uri)?.as_ptr(),
33 &mut connection.sqlite3,
34 flags,
35 0 as *const _,
36 );
37
38 // Turn on extended error codes
39 sqlite3_extended_result_codes(connection.sqlite3, 1);
40
41 connection.last_error()?;
42 }
43
44 Ok(connection)
45 }
46
47 /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
48 /// instead.
49 pub fn open_file(uri: &str) -> Self {
50 Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(Some(uri)))
51 }
52
53 pub fn open_memory(uri: Option<&str>) -> Self {
54 let in_memory_path = if let Some(uri) = uri {
55 format!("file:{}?mode=memory&cache=shared", uri)
56 } else {
57 ":memory:".to_string()
58 };
59
60 Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
61 }
62
63 pub fn persistent(&self) -> bool {
64 self.persistent
65 }
66
67 pub fn can_write(&self) -> bool {
68 *self.write.borrow()
69 }
70
71 pub fn backup_main(&self, destination: &Connection) -> Result<()> {
72 unsafe {
73 let backup = sqlite3_backup_init(
74 destination.sqlite3,
75 CString::new("main")?.as_ptr(),
76 self.sqlite3,
77 CString::new("main")?.as_ptr(),
78 );
79 sqlite3_backup_step(backup, -1);
80 sqlite3_backup_finish(backup);
81 destination.last_error()
82 }
83 }
84
85 pub fn backup_main_to(&self, destination: impl AsRef<Path>) -> Result<()> {
86 let destination = Self::open_file(destination.as_ref().to_string_lossy().as_ref());
87 self.backup_main(&destination)
88 }
89
90 pub fn sql_has_syntax_error(&self, sql: &str) -> Option<(String, usize)> {
91 let sql = CString::new(sql).unwrap();
92 let mut remaining_sql = sql.as_c_str();
93 let sql_start = remaining_sql.as_ptr();
94
95 unsafe {
96 while {
97 let remaining_sql_str = remaining_sql.to_str().unwrap().trim();
98 remaining_sql_str != ";" && !remaining_sql_str.is_empty()
99 } {
100 let mut raw_statement = 0 as *mut sqlite3_stmt;
101 let mut remaining_sql_ptr = ptr::null();
102 sqlite3_prepare_v2(
103 self.sqlite3,
104 remaining_sql.as_ptr(),
105 -1,
106 &mut raw_statement,
107 &mut remaining_sql_ptr,
108 );
109
110 let res = sqlite3_errcode(self.sqlite3);
111 let offset = sqlite3_error_offset(self.sqlite3);
112 let message = sqlite3_errmsg(self.sqlite3);
113
114 sqlite3_finalize(raw_statement);
115
116 if res == 1 && offset >= 0 {
117 let err_msg =
118 String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
119 .into_owned();
120 let sub_statement_correction =
121 remaining_sql.as_ptr() as usize - sql_start as usize;
122
123 return Some((err_msg, offset as usize + sub_statement_correction));
124 }
125 remaining_sql = CStr::from_ptr(remaining_sql_ptr);
126 }
127 }
128 None
129 }
130
131 pub(crate) fn last_error(&self) -> Result<()> {
132 unsafe {
133 let code = sqlite3_errcode(self.sqlite3);
134 const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
135 if NON_ERROR_CODES.contains(&code) {
136 return Ok(());
137 }
138
139 let message = sqlite3_errmsg(self.sqlite3);
140 let message = if message.is_null() {
141 None
142 } else {
143 Some(
144 String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
145 .into_owned(),
146 )
147 };
148
149 Err(anyhow!(
150 "Sqlite call failed with code {} and message: {:?}",
151 code as isize,
152 message
153 ))
154 }
155 }
156
157 pub(crate) fn with_write<T>(&self, callback: impl FnOnce(&Connection) -> T) -> T {
158 *self.write.borrow_mut() = true;
159 let result = callback(self);
160 *self.write.borrow_mut() = false;
161 result
162 }
163}
164
165impl Drop for Connection {
166 fn drop(&mut self) {
167 unsafe { sqlite3_close(self.sqlite3) };
168 }
169}
170
#[cfg(test)]
mod test {
    use anyhow::Result;
    use indoc::indoc;

    use crate::connection::Connection;

    /// A string bound into an INSERT comes back unchanged from a SELECT.
    #[test]
    fn string_round_trips() -> Result<()> {
        let connection = Connection::open_memory(Some("string_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE text (
                    text TEXT
                );"})
            .unwrap()()
        .unwrap();

        let text = "Some test text";

        connection
            .exec_bound("INSERT INTO text (text) VALUES (?);")
            .unwrap()(text)
        .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM text;").unwrap()().unwrap(),
            Some(text.to_string())
        );

        Ok(())
    }

    /// Tuples of (text, integer, blob) survive a round trip through a table.
    #[test]
    fn tuple_round_trips() {
        let connection = Connection::open_memory(Some("tuple_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    text TEXT,
                    integer INTEGER,
                    blob BLOB
                );"})
            .unwrap()()
        .unwrap();

        let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
        let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);

        let mut insert = connection
            .exec_bound::<(String, usize, Vec<u8>)>(
                "INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)",
            )
            .unwrap();

        insert(tuple1.clone()).unwrap();
        insert(tuple2.clone()).unwrap();

        assert_eq!(
            connection
                .select::<(String, usize, Vec<u8>)>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            vec![tuple1, tuple2]
        );
    }

    /// Booleans stored in INTEGER columns are read back as the same booleans.
    #[test]
    fn bool_round_trips() {
        let connection = Connection::open_memory(Some("bool_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE bools (
                    t INTEGER,
                    f INTEGER
                );"})
            .unwrap()()
        .unwrap();

        connection
            .exec_bound("INSERT INTO bools(t, f) VALUES (?, ?)")
            .unwrap()((true, false))
        .unwrap();

        assert_eq!(
            connection
                .select_row::<(bool, bool)>("SELECT * FROM bools;")
                .unwrap()()
            .unwrap(),
            Some((true, false))
        );
    }

    /// Data written to one connection is visible in another connection after a backup.
    #[test]
    fn backup_works() {
        let connection1 = Connection::open_memory(Some("backup_works"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap()()
        .unwrap();
        let blob = vec![0, 1, 2, 4, 8, 16, 32, 64];
        connection1
            .exec_bound::<Vec<u8>>("INSERT INTO blobs (data) VALUES (?);")
            .unwrap()(blob.clone())
        .unwrap();

        // Backup connection1 to connection2
        let connection2 = Connection::open_memory(Some("backup_works_other"));
        connection1.backup_main(&connection2).unwrap();

        // Read the blob back through the destination connection: reading it from the
        // source would pass even if the backup had copied nothing.
        let read_blobs = connection2
            .select::<Vec<u8>>("SELECT * FROM blobs;")
            .unwrap()()
        .unwrap();
        assert_eq!(read_blobs, vec![blob]);
    }

    /// Statements without a trailing semicolon still execute.
    #[test]
    fn multi_step_statement_works() {
        let connection = Connection::open_memory(Some("multi_step_statement_works"));

        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    col INTEGER
                )"})
            .unwrap()()
        .unwrap();

        connection
            .exec(indoc! {"
                INSERT INTO test(col) VALUES (2)"})
            .unwrap()()
        .unwrap();

        assert_eq!(
            connection
                .select_row::<usize>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            Some(2)
        );
    }

    /// The offset reported for an error in a later statement is relative to the whole input.
    #[test]
    fn test_sql_has_syntax_errors() {
        let connection = Connection::open_memory(Some("test_sql_has_syntax_errors"));
        let first_stmt =
            "CREATE TABLE kv_store(key TEXT PRIMARY KEY, value TEXT NOT NULL) STRICT ;";
        let second_stmt = "SELECT FROM";

        let second_offset = connection.sql_has_syntax_error(second_stmt).unwrap().1;

        let res = connection
            .sql_has_syntax_error(&format!("{}\n{}", first_stmt, second_stmt))
            .map(|(_, offset)| offset);

        // The `+ 1` accounts for the newline joining the two statements.
        assert_eq!(res, Some(first_stmt.len() + second_offset + 1));
    }
}