1use std::{
2 cell::RefCell,
3 ffi::{CStr, CString},
4 marker::PhantomData,
5 path::Path,
6 ptr,
7};
8
9use anyhow::{Result, anyhow};
10use libsqlite3_sys::*;
11
12pub struct Connection {
13 pub(crate) sqlite3: *mut sqlite3,
14 persistent: bool,
15 pub(crate) write: RefCell<bool>,
16 _sqlite: PhantomData<sqlite3>,
17}
18unsafe impl Send for Connection {}
19
20impl Connection {
21 pub(crate) fn open(uri: &str, persistent: bool) -> Result<Self> {
22 let mut connection = Self {
23 sqlite3: ptr::null_mut(),
24 persistent,
25 write: RefCell::new(true),
26 _sqlite: PhantomData,
27 };
28
29 let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
30 unsafe {
31 sqlite3_open_v2(
32 CString::new(uri)?.as_ptr(),
33 &mut connection.sqlite3,
34 flags,
35 ptr::null(),
36 );
37
38 // Turn on extended error codes
39 sqlite3_extended_result_codes(connection.sqlite3, 1);
40
41 connection.last_error()?;
42 }
43
44 Ok(connection)
45 }
46
47 /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
48 /// instead.
49 pub fn open_file(uri: &str) -> Self {
50 Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(Some(uri)))
51 }
52
53 pub fn open_memory(uri: Option<&str>) -> Self {
54 let in_memory_path = if let Some(uri) = uri {
55 format!("file:{}?mode=memory&cache=shared", uri)
56 } else {
57 ":memory:".to_string()
58 };
59
60 Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
61 }
62
63 pub fn persistent(&self) -> bool {
64 self.persistent
65 }
66
67 pub fn can_write(&self) -> bool {
68 *self.write.borrow()
69 }
70
71 pub fn backup_main(&self, destination: &Connection) -> Result<()> {
72 unsafe {
73 let backup = sqlite3_backup_init(
74 destination.sqlite3,
75 CString::new("main")?.as_ptr(),
76 self.sqlite3,
77 CString::new("main")?.as_ptr(),
78 );
79 sqlite3_backup_step(backup, -1);
80 sqlite3_backup_finish(backup);
81 destination.last_error()
82 }
83 }
84
85 pub fn backup_main_to(&self, destination: impl AsRef<Path>) -> Result<()> {
86 let destination = Self::open_file(destination.as_ref().to_string_lossy().as_ref());
87 self.backup_main(&destination)
88 }
89
90 pub fn sql_has_syntax_error(&self, sql: &str) -> Option<(String, usize)> {
91 let sql = CString::new(sql).unwrap();
92 let mut remaining_sql = sql.as_c_str();
93 let sql_start = remaining_sql.as_ptr();
94
95 unsafe {
96 let mut alter_table = None;
97 while {
98 let remaining_sql_str = remaining_sql.to_str().unwrap().trim();
99 let any_remaining_sql = remaining_sql_str != ";" && !remaining_sql_str.is_empty();
100 if any_remaining_sql {
101 alter_table = parse_alter_table(remaining_sql_str);
102 }
103 any_remaining_sql
104 } {
105 let mut raw_statement = ptr::null_mut::<sqlite3_stmt>();
106 let mut remaining_sql_ptr = ptr::null();
107
108 let (res, offset, message, _conn) =
109 if let Some((table_to_alter, column)) = alter_table {
110 // ALTER TABLE is a weird statement. When preparing the statement the table's
111 // existence is checked *before* syntax checking any other part of the statement.
112 // Therefore, we need to make sure that the table has been created before calling
113 // prepare. As we don't want to trash whatever database this is connected to, we
114 // create a new in-memory DB to test.
115
116 let temp_connection = Connection::open_memory(None);
117 //This should always succeed, if it doesn't then you really should know about it
118 temp_connection
119 .exec(&format!("CREATE TABLE {table_to_alter}({column})"))
120 .unwrap()()
121 .unwrap();
122
123 sqlite3_prepare_v2(
124 temp_connection.sqlite3,
125 remaining_sql.as_ptr(),
126 -1,
127 &mut raw_statement,
128 &mut remaining_sql_ptr,
129 );
130
131 #[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
132 let offset = sqlite3_error_offset(temp_connection.sqlite3);
133
134 #[cfg(any(target_os = "linux", target_os = "freebsd"))]
135 let offset = 0;
136
137 (
138 sqlite3_errcode(temp_connection.sqlite3),
139 offset,
140 sqlite3_errmsg(temp_connection.sqlite3),
141 Some(temp_connection),
142 )
143 } else {
144 sqlite3_prepare_v2(
145 self.sqlite3,
146 remaining_sql.as_ptr(),
147 -1,
148 &mut raw_statement,
149 &mut remaining_sql_ptr,
150 );
151
152 #[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
153 let offset = sqlite3_error_offset(self.sqlite3);
154
155 #[cfg(any(target_os = "linux", target_os = "freebsd"))]
156 let offset = 0;
157
158 (
159 sqlite3_errcode(self.sqlite3),
160 offset,
161 sqlite3_errmsg(self.sqlite3),
162 None,
163 )
164 };
165
166 sqlite3_finalize(raw_statement);
167
168 if res == 1 && offset >= 0 {
169 let sub_statement_correction =
170 remaining_sql.as_ptr() as usize - sql_start as usize;
171 let err_msg =
172 String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
173 .into_owned();
174
175 return Some((err_msg, offset as usize + sub_statement_correction));
176 }
177 remaining_sql = CStr::from_ptr(remaining_sql_ptr);
178 alter_table = None;
179 }
180 }
181 None
182 }
183
184 pub(crate) fn last_error(&self) -> Result<()> {
185 unsafe {
186 let code = sqlite3_errcode(self.sqlite3);
187 const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
188 if NON_ERROR_CODES.contains(&code) {
189 return Ok(());
190 }
191
192 let message = sqlite3_errmsg(self.sqlite3);
193 let message = if message.is_null() {
194 None
195 } else {
196 Some(
197 String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
198 .into_owned(),
199 )
200 };
201
202 Err(anyhow!(
203 "Sqlite call failed with code {} and message: {:?}",
204 code as isize,
205 message
206 ))
207 }
208 }
209
210 pub(crate) fn with_write<T>(&self, callback: impl FnOnce(&Connection) -> T) -> T {
211 *self.write.borrow_mut() = true;
212 let result = callback(self);
213 *self.write.borrow_mut() = false;
214 result
215 }
216}
217
/// Extracts `(table, column)` from an `ALTER TABLE` statement at the start of
/// `remaining_sql_str`, so the caller can create a matching scratch table before asking SQLite
/// to prepare the statement.
///
/// Matching is case-insensitive and the returned names are lowercased. Only the first
/// `;`-terminated statement is inspected, so keywords in later statements are ignored and the
/// terminating `;` never leaks into a name. The column is the word after `RENAME COLUMN` or
/// `DROP COLUMN`. If neither is present, or nothing follows the keyword, a placeholder column
/// name is returned.
///
/// Returns `None` if the statement does not start with `ALTER` or no table name follows
/// `TABLE`. Names are split on whitespace, so quoted identifiers containing spaces are not
/// handled.
fn parse_alter_table(remaining_sql_str: &str) -> Option<(String, String)> {
    const PLACEHOLDER_COLUMN: &str = "__place_holder_column_for_syntax_checking";

    /// Returns the first whitespace-delimited word following the first occurrence of `keyword`.
    fn word_after<'a>(haystack: &'a str, keyword: &str) -> Option<&'a str> {
        // `find` yields a byte offset, so slice by bytes rather than skipping chars.
        let start = haystack.find(keyword)? + keyword.len();
        haystack[start..].split_whitespace().next()
    }

    let lowered = remaining_sql_str.to_lowercase();
    let statement = lowered.split(';').next().unwrap_or("");
    if !statement.starts_with("alter") {
        return None;
    }

    let table_to_alter = word_after(statement, "table")?;
    let column_name = word_after(statement, "rename column")
        .or_else(|| word_after(statement, "drop column"))
        .unwrap_or(PLACEHOLDER_COLUMN);

    Some((table_to_alter.to_string(), column_name.to_string()))
}
256
257impl Drop for Connection {
258 fn drop(&mut self) {
259 unsafe { sqlite3_close(self.sqlite3) };
260 }
261}
262
#[cfg(test)]
mod test {
    use anyhow::Result;
    use indoc::indoc;

    use crate::connection::Connection;

    #[test]
    fn string_round_trips() -> Result<()> {
        let connection = Connection::open_memory(Some("string_round_trips"));
        connection
            .exec(indoc! {"
            CREATE TABLE text (
                text TEXT
            );"})
            .unwrap()()
        .unwrap();

        let text = "Some test text";

        connection
            .exec_bound("INSERT INTO text (text) VALUES (?);")
            .unwrap()(text)
        .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM text;").unwrap()().unwrap(),
            Some(text.to_string())
        );

        Ok(())
    }

    #[test]
    fn tuple_round_trips() {
        let connection = Connection::open_memory(Some("tuple_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    text TEXT,
                    integer INTEGER,
                    blob BLOB
                );"})
            .unwrap()()
        .unwrap();

        let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
        let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);

        let mut insert = connection
            .exec_bound::<(String, usize, Vec<u8>)>(
                "INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)",
            )
            .unwrap();

        insert(tuple1.clone()).unwrap();
        insert(tuple2.clone()).unwrap();

        assert_eq!(
            connection
                .select::<(String, usize, Vec<u8>)>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            vec![tuple1, tuple2]
        );
    }

    #[test]
    fn bool_round_trips() {
        let connection = Connection::open_memory(Some("bool_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE bools (
                    t INTEGER,
                    f INTEGER
                );"})
            .unwrap()()
        .unwrap();

        connection
            .exec_bound("INSERT INTO bools(t, f) VALUES (?, ?)")
            .unwrap()((true, false))
        .unwrap();

        assert_eq!(
            connection
                .select_row::<(bool, bool)>("SELECT * FROM bools;")
                .unwrap()()
            .unwrap(),
            Some((true, false))
        );
    }

    #[test]
    fn backup_works() {
        let connection1 = Connection::open_memory(Some("backup_works"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap()()
        .unwrap();
        let blob = vec![0, 1, 2, 4, 8, 16, 32, 64];
        connection1
            .exec_bound::<Vec<u8>>("INSERT INTO blobs (data) VALUES (?);")
            .unwrap()(blob.clone())
        .unwrap();

        // Backup connection1 to connection2
        let connection2 = Connection::open_memory(Some("backup_works_other"));
        connection1.backup_main(&connection2).unwrap();

        // The blob written through connection1 must now be readable through connection2
        let read_blobs = connection2
            .select::<Vec<u8>>("SELECT * FROM blobs;")
            .unwrap()()
        .unwrap();
        assert_eq!(read_blobs, vec![blob]);
    }

    #[test]
    fn multi_step_statement_works() {
        let connection = Connection::open_memory(Some("multi_step_statement_works"));

        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    col INTEGER
                )"})
            .unwrap()()
        .unwrap();

        connection
            .exec(indoc! {"
            INSERT INTO test(col) VALUES (2)"})
            .unwrap()()
        .unwrap();

        assert_eq!(
            connection
                .select_row::<usize>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            Some(2)
        );
    }

    // Exact error offsets are only available where sqlite3_error_offset is used.
    #[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
    #[test]
    fn test_sql_has_syntax_errors() {
        let connection = Connection::open_memory(Some("test_sql_has_syntax_errors"));
        let first_stmt =
            "CREATE TABLE kv_store(key TEXT PRIMARY KEY, value TEXT NOT NULL) STRICT ;";
        let second_stmt = "SELECT FROM";

        let second_offset = connection.sql_has_syntax_error(second_stmt).unwrap().1;

        let res = connection
            .sql_has_syntax_error(&format!("{}\n{}", first_stmt, second_stmt))
            .map(|(_, offset)| offset);

        // The error in the second statement is reported relative to the whole input
        // (first statement + newline + offset within the second statement).
        assert_eq!(res, Some(first_stmt.len() + second_offset + 1));
    }

    #[test]
    fn test_alter_table_syntax() {
        let connection = Connection::open_memory(Some("test_alter_table_syntax"));

        // The table does not exist on this connection; the check must still pass.
        assert!(
            connection
                .sql_has_syntax_error("ALTER TABLE test ADD x TEXT")
                .is_none()
        );

        assert!(
            connection
                .sql_has_syntax_error("ALTER TABLE test AAD x TEXT")
                .is_some()
        );
    }
}