1use std::{
2 cell::RefCell,
3 ffi::{CStr, CString},
4 marker::PhantomData,
5 path::Path,
6 ptr,
7};
8
9use anyhow::Result;
10use libsqlite3_sys::*;
11
/// An owned SQLite database connection.
///
/// The raw handle is released when the `Connection` is dropped.
pub struct Connection {
    /// Raw SQLite handle owned by this connection (null only if opening failed before
    /// SQLite could allocate one).
    pub(crate) sqlite3: *mut sqlite3,
    /// `true` when the connection was opened as a file database via `open_file`'s
    /// first attempt; in-memory connections are opened with `false`.
    persistent: bool,
    /// Write flag reported by `can_write`; starts as `true` and is forced to `true`
    /// while a `with_write` callback runs.
    pub(crate) write: RefCell<bool>,
    _sqlite: PhantomData<sqlite3>,
}
// SAFETY: every connection in this file is opened with `SQLITE_OPEN_NOMUTEX`
// (multi-thread mode), where SQLite allows a connection to be used from any thread as
// long as it is not used by two threads at once. The handle is exclusively owned by
// this struct and no `Sync` impl is provided here, so moving it to another thread is
// sound. NOTE(review): this assumes the linked SQLite was built thread-safe
// (`SQLITE_THREADSAFE != 0`) — confirm for the system library on Linux/FreeBSD.
unsafe impl Send for Connection {}
19
20impl Connection {
21 fn open_with_flags(uri: &str, persistent: bool, flags: i32) -> Result<Self> {
22 let mut connection = Self {
23 sqlite3: ptr::null_mut(),
24 persistent,
25 write: RefCell::new(true),
26 _sqlite: PhantomData,
27 };
28
29 unsafe {
30 sqlite3_open_v2(
31 CString::new(uri)?.as_ptr(),
32 &mut connection.sqlite3,
33 flags,
34 ptr::null(),
35 );
36
37 // Turn on extended error codes
38 sqlite3_extended_result_codes(connection.sqlite3, 1);
39
40 connection.last_error()?;
41 }
42
43 Ok(connection)
44 }
45
46 pub(crate) fn open(uri: &str, persistent: bool) -> Result<Self> {
47 Self::open_with_flags(
48 uri,
49 persistent,
50 SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE,
51 )
52 }
53
54 /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
55 /// instead.
56 pub fn open_file(uri: &str) -> Self {
57 Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(Some(uri)))
58 }
59
60 pub fn open_memory(uri: Option<&str>) -> Self {
61 if let Some(uri) = uri {
62 let in_memory_path = format!("file:{}?mode=memory&cache=shared", uri);
63 return Self::open_with_flags(
64 &in_memory_path,
65 false,
66 SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE | SQLITE_OPEN_URI,
67 )
68 .expect("Could not create fallback in memory db");
69 } else {
70 Self::open(":memory:", false).expect("Could not create fallback in memory db")
71 }
72 }
73
74 pub fn persistent(&self) -> bool {
75 self.persistent
76 }
77
78 pub fn can_write(&self) -> bool {
79 *self.write.borrow()
80 }
81
82 pub fn backup_main(&self, destination: &Connection) -> Result<()> {
83 unsafe {
84 let backup = sqlite3_backup_init(
85 destination.sqlite3,
86 CString::new("main")?.as_ptr(),
87 self.sqlite3,
88 CString::new("main")?.as_ptr(),
89 );
90 sqlite3_backup_step(backup, -1);
91 sqlite3_backup_finish(backup);
92 destination.last_error()
93 }
94 }
95
96 pub fn backup_main_to(&self, destination: impl AsRef<Path>) -> Result<()> {
97 let destination = Self::open_file(destination.as_ref().to_string_lossy().as_ref());
98 self.backup_main(&destination)
99 }
100
101 pub fn sql_has_syntax_error(&self, sql: &str) -> Option<(String, usize)> {
102 let sql = CString::new(sql).unwrap();
103 let mut remaining_sql = sql.as_c_str();
104 let sql_start = remaining_sql.as_ptr();
105
106 let mut alter_table = None;
107 while {
108 let remaining_sql_str = remaining_sql.to_str().unwrap().trim();
109 let any_remaining_sql = remaining_sql_str != ";" && !remaining_sql_str.is_empty();
110 if any_remaining_sql {
111 alter_table = parse_alter_table(remaining_sql_str);
112 }
113 any_remaining_sql
114 } {
115 let mut raw_statement = ptr::null_mut::<sqlite3_stmt>();
116 let mut remaining_sql_ptr = ptr::null();
117
118 let (res, offset, message, _conn) = if let Some((table_to_alter, column)) = alter_table
119 {
120 // ALTER TABLE is a weird statement. When preparing the statement the table's
121 // existence is checked *before* syntax checking any other part of the statement.
122 // Therefore, we need to make sure that the table has been created before calling
123 // prepare. As we don't want to trash whatever database this is connected to, we
124 // create a new in-memory DB to test.
125
126 let temp_connection = Connection::open_memory(None);
127 //This should always succeed, if it doesn't then you really should know about it
128 temp_connection
129 .exec(&format!("CREATE TABLE {table_to_alter}({column})"))
130 .unwrap()()
131 .unwrap();
132
133 unsafe {
134 sqlite3_prepare_v2(
135 temp_connection.sqlite3,
136 remaining_sql.as_ptr(),
137 -1,
138 &mut raw_statement,
139 &mut remaining_sql_ptr,
140 )
141 };
142
143 #[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
144 let offset = unsafe { sqlite3_error_offset(temp_connection.sqlite3) };
145
146 #[cfg(any(target_os = "linux", target_os = "freebsd"))]
147 let offset = 0;
148
149 unsafe {
150 (
151 sqlite3_errcode(temp_connection.sqlite3),
152 offset,
153 sqlite3_errmsg(temp_connection.sqlite3),
154 Some(temp_connection),
155 )
156 }
157 } else {
158 unsafe {
159 sqlite3_prepare_v2(
160 self.sqlite3,
161 remaining_sql.as_ptr(),
162 -1,
163 &mut raw_statement,
164 &mut remaining_sql_ptr,
165 )
166 };
167
168 #[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
169 let offset = unsafe { sqlite3_error_offset(self.sqlite3) };
170
171 #[cfg(any(target_os = "linux", target_os = "freebsd"))]
172 let offset = 0;
173
174 unsafe {
175 (
176 sqlite3_errcode(self.sqlite3),
177 offset,
178 sqlite3_errmsg(self.sqlite3),
179 None,
180 )
181 }
182 };
183
184 unsafe { sqlite3_finalize(raw_statement) };
185
186 if res == 1 && offset >= 0 {
187 let sub_statement_correction = remaining_sql.as_ptr() as usize - sql_start as usize;
188 let err_msg = String::from_utf8_lossy(unsafe {
189 CStr::from_ptr(message as *const _).to_bytes()
190 })
191 .into_owned();
192
193 return Some((err_msg, offset as usize + sub_statement_correction));
194 }
195 remaining_sql = unsafe { CStr::from_ptr(remaining_sql_ptr) };
196 alter_table = None;
197 }
198 None
199 }
200
201 pub(crate) fn last_error(&self) -> Result<()> {
202 unsafe {
203 let code = sqlite3_errcode(self.sqlite3);
204 const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
205 if NON_ERROR_CODES.contains(&code) {
206 return Ok(());
207 }
208
209 let message = sqlite3_errmsg(self.sqlite3);
210 let message = if message.is_null() {
211 None
212 } else {
213 Some(
214 String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
215 .into_owned(),
216 )
217 };
218
219 anyhow::bail!("Sqlite call failed with code {code} and message: {message:?}")
220 }
221 }
222
223 pub(crate) fn with_write<T>(&self, callback: impl FnOnce(&Connection) -> T) -> T {
224 *self.write.borrow_mut() = true;
225 let result = callback(self);
226 *self.write.borrow_mut() = false;
227 result
228 }
229}
230
/// Recognizes an `ALTER TABLE` statement at the start of `remaining_sql_str`.
///
/// Only the first `;`-terminated statement is inspected, so later statements in the
/// same string cannot influence the result. Matching is case-insensitive and the
/// returned names are lowercased.
///
/// Returns `None` unless the statement starts with `ALTER TABLE <name>`. Otherwise it
/// returns the table name and the column-definition list to use when creating a
/// scratch copy of that table, so the statement can be prepared against it:
/// - `RENAME [COLUMN] <col> TO ...` gives `<col>`.
/// - `DROP [COLUMN] <col>` gives `<col>` plus a placeholder column, because SQLite
///   refuses to drop a table's only column.
/// - Anything else (e.g. `ADD ...` or `RENAME TO ...`) gives a placeholder column.
fn parse_alter_table(remaining_sql_str: &str) -> Option<(String, String)> {
    const PLACEHOLDER_COLUMN: &str = "__place_holder_column_for_syntax_checking";

    let lowercase_sql = remaining_sql_str.to_lowercase();
    let first_statement = lowercase_sql.split(';').next().unwrap_or("");
    let mut words = first_statement.split_whitespace();

    if words.next()? != "alter" || words.next()? != "table" {
        return None;
    }
    let table_to_alter = words.next()?.to_string();

    let action = words.next();
    let mut target = words.next();
    // SQLite's `COLUMN` keyword is optional in both `RENAME` and `DROP`.
    if target == Some("column") {
        target = words.next();
    }

    let column_definitions = match (action, target) {
        // `RENAME TO <new_name>` renames the table itself; no column is needed.
        (Some("rename"), Some(column)) if column != "to" => column.to_string(),
        (Some("drop"), Some(column)) => format!("{}, {}", column, PLACEHOLDER_COLUMN),
        _ => PLACEHOLDER_COLUMN.to_string(),
    };
    Some((table_to_alter, column_definitions))
}
268
269impl Drop for Connection {
270 fn drop(&mut self) {
271 unsafe { sqlite3_close(self.sqlite3) };
272 }
273}
274
#[cfg(test)]
mod test {
    use anyhow::Result;
    use indoc::indoc;
    use std::{
        fs,
        sync::atomic::{AtomicUsize, Ordering},
    };

    use crate::connection::Connection;

    static NEXT_NAMED_MEMORY_DB_ID: AtomicUsize = AtomicUsize::new(0);

    /// Builds a shared in-memory database name unique to this process and call.
    fn unique_named_memory_db(prefix: &str) -> String {
        format!(
            "{prefix}_{}_{}",
            std::process::id(),
            NEXT_NAMED_MEMORY_DB_ID.fetch_add(1, Ordering::Relaxed)
        )
    }

    /// The file names that would appear on disk if the in-memory URI for `name` were
    /// treated as a literal path: the main file plus its `-wal` and `-shm` companions.
    fn literal_named_memory_paths(name: &str) -> [String; 3] {
        let main = format!("file:{name}?mode=memory&cache=shared");
        [main.clone(), format!("{main}-wal"), format!("{main}-shm")]
    }

    /// Deletes any such literal files when created and again when dropped.
    struct NamedMemoryPathGuard {
        paths: [String; 3],
    }

    impl NamedMemoryPathGuard {
        fn new(name: &str) -> Self {
            let paths = literal_named_memory_paths(name);
            for path in &paths {
                let _ = fs::remove_file(path);
            }
            Self { paths }
        }
    }

    impl Drop for NamedMemoryPathGuard {
        fn drop(&mut self) {
            for path in &self.paths {
                let _ = fs::remove_file(path);
            }
        }
    }

    #[test]
    fn string_round_trips() -> Result<()> {
        let connection = Connection::open_memory(Some("string_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE text (
                    text TEXT
                );"})
            .unwrap()()
        .unwrap();

        let text = "Some test text";

        connection
            .exec_bound("INSERT INTO text (text) VALUES (?);")
            .unwrap()(text)
        .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM text;").unwrap()().unwrap(),
            Some(text.to_string())
        );

        Ok(())
    }

    #[test]
    fn tuple_round_trips() {
        let connection = Connection::open_memory(Some("tuple_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    text TEXT,
                    integer INTEGER,
                    blob BLOB
                );"})
            .unwrap()()
        .unwrap();

        let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
        let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);

        let mut insert = connection
            .exec_bound::<(String, usize, Vec<u8>)>(
                "INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)",
            )
            .unwrap();

        insert(tuple1.clone()).unwrap();
        insert(tuple2.clone()).unwrap();

        assert_eq!(
            connection
                .select::<(String, usize, Vec<u8>)>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            vec![tuple1, tuple2]
        );
    }

    #[test]
    fn bool_round_trips() {
        let connection = Connection::open_memory(Some("bool_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE bools (
                    t INTEGER,
                    f INTEGER
                );"})
            .unwrap()()
        .unwrap();

        connection
            .exec_bound("INSERT INTO bools(t, f) VALUES (?, ?)")
            .unwrap()((true, false))
        .unwrap();

        assert_eq!(
            connection
                .select_row::<(bool, bool)>("SELECT * FROM bools;")
                .unwrap()()
            .unwrap(),
            Some((true, false))
        );
    }

    #[test]
    fn backup_works() {
        let connection1 = Connection::open_memory(Some("backup_works"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap()()
        .unwrap();
        let blob = vec![0, 1, 2, 4, 8, 16, 32, 64];
        connection1
            .exec_bound::<Vec<u8>>("INSERT INTO blobs (data) VALUES (?);")
            .unwrap()(blob.clone())
        .unwrap();

        // Backup connection1 to connection2
        let connection2 = Connection::open_memory(Some("backup_works_other"));
        connection1.backup_main(&connection2).unwrap();

        // Read from the backup destination to verify the blob was copied over
        let read_blobs = connection2
            .select::<Vec<u8>>("SELECT * FROM blobs;")
            .unwrap()()
        .unwrap();
        assert_eq!(read_blobs, vec![blob]);
    }

    #[test]
    fn named_memory_connections_do_not_create_literal_backing_files() {
        let name = unique_named_memory_db("named_memory_connections_do_not_create_backing_files");
        let guard = NamedMemoryPathGuard::new(&name);

        let connection1 = Connection::open_memory(Some(&name));
        connection1
            .exec(indoc! {"
                CREATE TABLE shared (
                    value INTEGER
                )"})
            .unwrap()()
        .unwrap();
        connection1
            .exec("INSERT INTO shared (value) VALUES (7)")
            .unwrap()()
        .unwrap();

        let connection2 = Connection::open_memory(Some(&name));
        assert_eq!(
            connection2
                .select_row::<i64>("SELECT value FROM shared")
                .unwrap()()
            .unwrap(),
            Some(7)
        );

        for path in &guard.paths {
            assert!(
                fs::metadata(path).is_err(),
                "named in-memory database unexpectedly created backing file {path}"
            );
        }
    }

    #[test]
    fn multi_step_statement_works() {
        let connection = Connection::open_memory(Some("multi_step_statement_works"));

        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    col INTEGER
                )"})
            .unwrap()()
        .unwrap();

        connection
            .exec(indoc! {"
            INSERT INTO test(col) VALUES (2)"})
            .unwrap()()
        .unwrap();

        assert_eq!(
            connection
                .select_row::<usize>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            Some(2)
        );
    }

    #[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
    #[test]
    fn test_sql_has_syntax_errors() {
        let connection = Connection::open_memory(Some("test_sql_has_syntax_errors"));
        let first_stmt =
            "CREATE TABLE kv_store(key TEXT PRIMARY KEY, value TEXT NOT NULL) STRICT ;";
        let second_stmt = "SELECT FROM";

        let second_offset = connection.sql_has_syntax_error(second_stmt).unwrap().1;

        let res = connection
            .sql_has_syntax_error(&format!("{}\n{}", first_stmt, second_stmt))
            .map(|(_, offset)| offset);

        assert_eq!(res, Some(first_stmt.len() + second_offset + 1));
    }

    #[test]
    fn test_alter_table_syntax() {
        let connection = Connection::open_memory(Some("test_alter_table_syntax"));

        assert!(
            connection
                .sql_has_syntax_error("ALTER TABLE test ADD x TEXT")
                .is_none()
        );

        assert!(
            connection
                .sql_has_syntax_error("ALTER TABLE test AAD x TEXT")
                .is_some()
        );
    }
}