1use std::{
2 cell::RefCell,
3 ffi::{CStr, CString},
4 marker::PhantomData,
5 path::Path,
6 ptr,
7};
8
9use anyhow::Result;
10use libsqlite3_sys::*;
11
12pub struct Connection {
13 pub(crate) sqlite3: *mut sqlite3,
14 persistent: bool,
15 pub(crate) write: RefCell<bool>,
16 _sqlite: PhantomData<sqlite3>,
17}
18unsafe impl Send for Connection {}
19
20impl Connection {
21 pub(crate) fn open(uri: &str, persistent: bool) -> Result<Self> {
22 let mut connection = Self {
23 sqlite3: ptr::null_mut(),
24 persistent,
25 write: RefCell::new(true),
26 _sqlite: PhantomData,
27 };
28
29 let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
30 unsafe {
31 sqlite3_open_v2(
32 CString::new(uri)?.as_ptr(),
33 &mut connection.sqlite3,
34 flags,
35 ptr::null(),
36 );
37
38 // Turn on extended error codes
39 sqlite3_extended_result_codes(connection.sqlite3, 1);
40
41 connection.last_error()?;
42 }
43
44 Ok(connection)
45 }
46
47 /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
48 /// instead.
49 pub fn open_file(uri: &str) -> Self {
50 Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(Some(uri)))
51 }
52
53 pub fn open_memory(uri: Option<&str>) -> Self {
54 let in_memory_path = if let Some(uri) = uri {
55 format!("file:{}?mode=memory&cache=shared", uri)
56 } else {
57 ":memory:".to_string()
58 };
59
60 Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
61 }
62
63 pub const fn persistent(&self) -> bool {
64 self.persistent
65 }
66
67 pub fn can_write(&self) -> bool {
68 *self.write.borrow()
69 }
70
71 pub fn backup_main(&self, destination: &Connection) -> Result<()> {
72 unsafe {
73 let backup = sqlite3_backup_init(
74 destination.sqlite3,
75 CString::new("main")?.as_ptr(),
76 self.sqlite3,
77 CString::new("main")?.as_ptr(),
78 );
79 sqlite3_backup_step(backup, -1);
80 sqlite3_backup_finish(backup);
81 destination.last_error()
82 }
83 }
84
85 pub fn backup_main_to(&self, destination: impl AsRef<Path>) -> Result<()> {
86 let destination = Self::open_file(destination.as_ref().to_string_lossy().as_ref());
87 self.backup_main(&destination)
88 }
89
90 pub fn sql_has_syntax_error(&self, sql: &str) -> Option<(String, usize)> {
91 let sql = CString::new(sql).unwrap();
92 let mut remaining_sql = sql.as_c_str();
93 let sql_start = remaining_sql.as_ptr();
94
95 let mut alter_table = None;
96 while {
97 let remaining_sql_str = remaining_sql.to_str().unwrap().trim();
98 let any_remaining_sql = remaining_sql_str != ";" && !remaining_sql_str.is_empty();
99 if any_remaining_sql {
100 alter_table = parse_alter_table(remaining_sql_str);
101 }
102 any_remaining_sql
103 } {
104 let mut raw_statement = ptr::null_mut::<sqlite3_stmt>();
105 let mut remaining_sql_ptr = ptr::null();
106
107 let (res, offset, message, _conn) = if let Some((table_to_alter, column)) = alter_table
108 {
109 // ALTER TABLE is a weird statement. When preparing the statement the table's
110 // existence is checked *before* syntax checking any other part of the statement.
111 // Therefore, we need to make sure that the table has been created before calling
112 // prepare. As we don't want to trash whatever database this is connected to, we
113 // create a new in-memory DB to test.
114
115 let temp_connection = Connection::open_memory(None);
116 //This should always succeed, if it doesn't then you really should know about it
117 temp_connection
118 .exec(&format!("CREATE TABLE {table_to_alter}({column})"))
119 .unwrap()()
120 .unwrap();
121
122 unsafe {
123 sqlite3_prepare_v2(
124 temp_connection.sqlite3,
125 remaining_sql.as_ptr(),
126 -1,
127 &mut raw_statement,
128 &mut remaining_sql_ptr,
129 )
130 };
131
132 #[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
133 let offset = unsafe { sqlite3_error_offset(temp_connection.sqlite3) };
134
135 #[cfg(any(target_os = "linux", target_os = "freebsd"))]
136 let offset = 0;
137
138 unsafe {
139 (
140 sqlite3_errcode(temp_connection.sqlite3),
141 offset,
142 sqlite3_errmsg(temp_connection.sqlite3),
143 Some(temp_connection),
144 )
145 }
146 } else {
147 unsafe {
148 sqlite3_prepare_v2(
149 self.sqlite3,
150 remaining_sql.as_ptr(),
151 -1,
152 &mut raw_statement,
153 &mut remaining_sql_ptr,
154 )
155 };
156
157 #[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
158 let offset = unsafe { sqlite3_error_offset(self.sqlite3) };
159
160 #[cfg(any(target_os = "linux", target_os = "freebsd"))]
161 let offset = 0;
162
163 unsafe {
164 (
165 sqlite3_errcode(self.sqlite3),
166 offset,
167 sqlite3_errmsg(self.sqlite3),
168 None,
169 )
170 }
171 };
172
173 unsafe { sqlite3_finalize(raw_statement) };
174
175 if res == 1 && offset >= 0 {
176 let sub_statement_correction = remaining_sql.as_ptr() as usize - sql_start as usize;
177 let err_msg = String::from_utf8_lossy(unsafe {
178 CStr::from_ptr(message as *const _).to_bytes()
179 })
180 .into_owned();
181
182 return Some((err_msg, offset as usize + sub_statement_correction));
183 }
184 remaining_sql = unsafe { CStr::from_ptr(remaining_sql_ptr) };
185 alter_table = None;
186 }
187 None
188 }
189
190 pub(crate) fn last_error(&self) -> Result<()> {
191 unsafe {
192 let code = sqlite3_errcode(self.sqlite3);
193 const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
194 if NON_ERROR_CODES.contains(&code) {
195 return Ok(());
196 }
197
198 let message = sqlite3_errmsg(self.sqlite3);
199 let message = if message.is_null() {
200 None
201 } else {
202 Some(
203 String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
204 .into_owned(),
205 )
206 };
207
208 anyhow::bail!("Sqlite call failed with code {code} and message: {message:?}")
209 }
210 }
211
212 pub(crate) fn with_write<T>(&self, callback: impl FnOnce(&Connection) -> T) -> T {
213 *self.write.borrow_mut() = true;
214 let result = callback(self);
215 *self.write.borrow_mut() = false;
216 result
217 }
218}
219
/// Recognises an `ALTER TABLE` statement at the start of `remaining_sql_str`.
///
/// Returns `Some((table, column))` (both lowercased) so the caller can create a
/// scratch table before preparing the statement. `column` is the name following
/// `RENAME COLUMN` or `DROP COLUMN`, or a placeholder column when neither is present
/// or no name follows. Returns `None` when the first statement is not an
/// `ALTER TABLE` or names no table.
///
/// Only the text before the first `;` is inspected, so a terminator or keywords from
/// later statements never end up in the extracted names.
fn parse_alter_table(remaining_sql_str: &str) -> Option<(String, String)> {
    const PLACEHOLDER_COLUMN: &str = "__place_holder_column_for_syntax_checking";

    let lowered = remaining_sql_str.to_lowercase();
    let statement = lowered.split(';').next().unwrap_or("");
    if !statement.starts_with("alter") {
        return None;
    }

    // First whitespace-delimited word of `text`. Callers slice `statement` using the
    // byte offsets returned by `find`, which stays correct for non-ASCII identifiers
    // (unlike skipping that many `char`s).
    let first_word = |text: &str| text.split_whitespace().next().unwrap_or("").to_string();

    let table_offset = statement.find("table")?;
    let table_to_alter = first_word(&statement[table_offset + "table".len()..]);
    if table_to_alter.is_empty() {
        return None;
    }

    let column_name = ["rename column", "drop column"]
        .iter()
        .find_map(|&keyword| {
            statement
                .find(keyword)
                .map(|offset| first_word(&statement[offset + keyword.len()..]))
        })
        // An empty name would produce an invalid `CREATE TABLE t()`.
        .filter(|column| !column.is_empty())
        .unwrap_or_else(|| PLACEHOLDER_COLUMN.to_string());

    Some((table_to_alter, column_name))
}
257
258impl Drop for Connection {
259 fn drop(&mut self) {
260 unsafe { sqlite3_close(self.sqlite3) };
261 }
262}
263
#[cfg(test)]
mod test {
    use anyhow::Result;
    use indoc::indoc;

    use crate::connection::Connection;

    #[test]
    fn string_round_trips() -> Result<()> {
        let connection = Connection::open_memory(Some("string_round_trips"));
        connection
            .exec(indoc! {"
            CREATE TABLE text (
                text TEXT
            );"})
            .unwrap()()
        .unwrap();

        let text = "Some test text";

        connection
            .exec_bound("INSERT INTO text (text) VALUES (?);")
            .unwrap()(text)
        .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM text;").unwrap()().unwrap(),
            Some(text.to_string())
        );

        Ok(())
    }

    #[test]
    fn tuple_round_trips() {
        let connection = Connection::open_memory(Some("tuple_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    text TEXT,
                    integer INTEGER,
                    blob BLOB
                );"})
            .unwrap()()
        .unwrap();

        let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
        let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);

        let mut insert = connection
            .exec_bound::<(String, usize, Vec<u8>)>(
                "INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)",
            )
            .unwrap();

        insert(tuple1.clone()).unwrap();
        insert(tuple2.clone()).unwrap();

        assert_eq!(
            connection
                .select::<(String, usize, Vec<u8>)>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            vec![tuple1, tuple2]
        );
    }

    #[test]
    fn bool_round_trips() {
        let connection = Connection::open_memory(Some("bool_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE bools (
                    t INTEGER,
                    f INTEGER
                );"})
            .unwrap()()
        .unwrap();

        connection
            .exec_bound("INSERT INTO bools(t, f) VALUES (?, ?)")
            .unwrap()((true, false))
        .unwrap();

        assert_eq!(
            connection
                .select_row::<(bool, bool)>("SELECT * FROM bools;")
                .unwrap()()
            .unwrap(),
            Some((true, false))
        );
    }

    #[test]
    fn backup_works() {
        let connection1 = Connection::open_memory(Some("backup_works"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap()()
        .unwrap();
        let blob = vec![0, 1, 2, 4, 8, 16, 32, 64];
        connection1
            .exec_bound::<Vec<u8>>("INSERT INTO blobs (data) VALUES (?);")
            .unwrap()(blob.clone())
        .unwrap();

        // Backup connection1 to connection2
        let connection2 = Connection::open_memory(Some("backup_works_other"));
        connection1.backup_main(&connection2).unwrap();

        // Verify the blob written to connection1 is readable on the other side
        let read_blobs = connection2
            .select::<Vec<u8>>("SELECT * FROM blobs;")
            .unwrap()()
        .unwrap();
        assert_eq!(read_blobs, vec![blob]);
    }

    #[test]
    fn multi_step_statement_works() {
        let connection = Connection::open_memory(Some("multi_step_statement_works"));

        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    col INTEGER
                )"})
            .unwrap()()
        .unwrap();

        connection
            .exec(indoc! {"
            INSERT INTO test(col) VALUES (2)"})
            .unwrap()()
        .unwrap();

        assert_eq!(
            connection
                .select_row::<usize>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            Some(2)
        );
    }

    #[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
    #[test]
    fn test_sql_has_syntax_errors() {
        let connection = Connection::open_memory(Some("test_sql_has_syntax_errors"));
        let first_stmt =
            "CREATE TABLE kv_store(key TEXT PRIMARY KEY, value TEXT NOT NULL) STRICT ;";
        let second_stmt = "SELECT FROM";

        let second_offset = connection.sql_has_syntax_error(second_stmt).unwrap().1;

        let res = connection
            .sql_has_syntax_error(&format!("{}\n{}", first_stmt, second_stmt))
            .map(|(_, offset)| offset);

        assert_eq!(res, Some(first_stmt.len() + second_offset + 1));
    }

    #[test]
    fn test_alter_table_syntax() {
        let connection = Connection::open_memory(Some("test_alter_table_syntax"));

        assert!(
            connection
                .sql_has_syntax_error("ALTER TABLE test ADD x TEXT")
                .is_none()
        );

        assert!(
            connection
                .sql_has_syntax_error("ALTER TABLE test AAD x TEXT")
                .is_some()
        );
    }
}