connection.rs

  1use std::{
  2    cell::RefCell,
  3    ffi::{CStr, CString},
  4    marker::PhantomData,
  5    path::Path,
  6    ptr,
  7};
  8
  9use anyhow::Result;
 10use libsqlite3_sys::*;
 11
 12pub struct Connection {
 13    pub(crate) sqlite3: *mut sqlite3,
 14    persistent: bool,
 15    pub(crate) write: RefCell<bool>,
 16    _sqlite: PhantomData<sqlite3>,
 17}
 18unsafe impl Send for Connection {}
 19
 20impl Connection {
 21    pub(crate) fn open(uri: &str, persistent: bool) -> Result<Self> {
 22        let mut connection = Self {
 23            sqlite3: ptr::null_mut(),
 24            persistent,
 25            write: RefCell::new(true),
 26            _sqlite: PhantomData,
 27        };
 28
 29        let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
 30        unsafe {
 31            sqlite3_open_v2(
 32                CString::new(uri)?.as_ptr(),
 33                &mut connection.sqlite3,
 34                flags,
 35                ptr::null(),
 36            );
 37
 38            // Turn on extended error codes
 39            sqlite3_extended_result_codes(connection.sqlite3, 1);
 40
 41            connection.last_error()?;
 42        }
 43
 44        Ok(connection)
 45    }
 46
 47    /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
 48    /// instead.
 49    pub fn open_file(uri: &str) -> Self {
 50        Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(Some(uri)))
 51    }
 52
 53    pub fn open_memory(uri: Option<&str>) -> Self {
 54        let in_memory_path = if let Some(uri) = uri {
 55            format!("file:{}?mode=memory&cache=shared", uri)
 56        } else {
 57            ":memory:".to_string()
 58        };
 59
 60        Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
 61    }
 62
 63    pub const fn persistent(&self) -> bool {
 64        self.persistent
 65    }
 66
 67    pub fn can_write(&self) -> bool {
 68        *self.write.borrow()
 69    }
 70
 71    pub fn backup_main(&self, destination: &Connection) -> Result<()> {
 72        unsafe {
 73            let backup = sqlite3_backup_init(
 74                destination.sqlite3,
 75                CString::new("main")?.as_ptr(),
 76                self.sqlite3,
 77                CString::new("main")?.as_ptr(),
 78            );
 79            sqlite3_backup_step(backup, -1);
 80            sqlite3_backup_finish(backup);
 81            destination.last_error()
 82        }
 83    }
 84
 85    pub fn backup_main_to(&self, destination: impl AsRef<Path>) -> Result<()> {
 86        let destination = Self::open_file(destination.as_ref().to_string_lossy().as_ref());
 87        self.backup_main(&destination)
 88    }
 89
 90    pub fn sql_has_syntax_error(&self, sql: &str) -> Option<(String, usize)> {
 91        let sql = CString::new(sql).unwrap();
 92        let mut remaining_sql = sql.as_c_str();
 93        let sql_start = remaining_sql.as_ptr();
 94
 95        let mut alter_table = None;
 96        while {
 97            let remaining_sql_str = remaining_sql.to_str().unwrap().trim();
 98            let any_remaining_sql = remaining_sql_str != ";" && !remaining_sql_str.is_empty();
 99            if any_remaining_sql {
100                alter_table = parse_alter_table(remaining_sql_str);
101            }
102            any_remaining_sql
103        } {
104            let mut raw_statement = ptr::null_mut::<sqlite3_stmt>();
105            let mut remaining_sql_ptr = ptr::null();
106
107            let (res, offset, message, _conn) = if let Some((table_to_alter, column)) = alter_table
108            {
109                // ALTER TABLE is a weird statement. When preparing the statement the table's
110                // existence is checked *before* syntax checking any other part of the statement.
111                // Therefore, we need to make sure that the table has been created before calling
112                // prepare. As we don't want to trash whatever database this is connected to, we
113                // create a new in-memory DB to test.
114
115                let temp_connection = Connection::open_memory(None);
116                //This should always succeed, if it doesn't then you really should know about it
117                temp_connection
118                    .exec(&format!("CREATE TABLE {table_to_alter}({column})"))
119                    .unwrap()()
120                .unwrap();
121
122                unsafe {
123                    sqlite3_prepare_v2(
124                        temp_connection.sqlite3,
125                        remaining_sql.as_ptr(),
126                        -1,
127                        &mut raw_statement,
128                        &mut remaining_sql_ptr,
129                    )
130                };
131
132                #[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
133                let offset = unsafe { sqlite3_error_offset(temp_connection.sqlite3) };
134
135                #[cfg(any(target_os = "linux", target_os = "freebsd"))]
136                let offset = 0;
137
138                unsafe {
139                    (
140                        sqlite3_errcode(temp_connection.sqlite3),
141                        offset,
142                        sqlite3_errmsg(temp_connection.sqlite3),
143                        Some(temp_connection),
144                    )
145                }
146            } else {
147                unsafe {
148                    sqlite3_prepare_v2(
149                        self.sqlite3,
150                        remaining_sql.as_ptr(),
151                        -1,
152                        &mut raw_statement,
153                        &mut remaining_sql_ptr,
154                    )
155                };
156
157                #[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
158                let offset = unsafe { sqlite3_error_offset(self.sqlite3) };
159
160                #[cfg(any(target_os = "linux", target_os = "freebsd"))]
161                let offset = 0;
162
163                unsafe {
164                    (
165                        sqlite3_errcode(self.sqlite3),
166                        offset,
167                        sqlite3_errmsg(self.sqlite3),
168                        None,
169                    )
170                }
171            };
172
173            unsafe { sqlite3_finalize(raw_statement) };
174
175            if res == 1 && offset >= 0 {
176                let sub_statement_correction = remaining_sql.as_ptr() as usize - sql_start as usize;
177                let err_msg = String::from_utf8_lossy(unsafe {
178                    CStr::from_ptr(message as *const _).to_bytes()
179                })
180                .into_owned();
181
182                return Some((err_msg, offset as usize + sub_statement_correction));
183            }
184            remaining_sql = unsafe { CStr::from_ptr(remaining_sql_ptr) };
185            alter_table = None;
186        }
187        None
188    }
189
190    pub(crate) fn last_error(&self) -> Result<()> {
191        unsafe {
192            let code = sqlite3_errcode(self.sqlite3);
193            const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
194            if NON_ERROR_CODES.contains(&code) {
195                return Ok(());
196            }
197
198            let message = sqlite3_errmsg(self.sqlite3);
199            let message = if message.is_null() {
200                None
201            } else {
202                Some(
203                    String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
204                        .into_owned(),
205                )
206            };
207
208            anyhow::bail!("Sqlite call failed with code {code} and message: {message:?}")
209        }
210    }
211
212    pub(crate) fn with_write<T>(&self, callback: impl FnOnce(&Connection) -> T) -> T {
213        *self.write.borrow_mut() = true;
214        let result = callback(self);
215        *self.write.borrow_mut() = false;
216        result
217    }
218}
219
/// Recognises an `ALTER TABLE` statement and extracts the names needed to build a
/// scratch table for syntax-checking it.
///
/// Returns `(table, column)`, both lowercased. `column` is the first word after
/// `RENAME COLUMN`, or failing that after `DROP COLUMN`. If neither names a column
/// (e.g. for `ADD`), `column` is a placeholder name.
///
/// Each extracted word ends at whitespace or `;`, so a trailing statement terminator
/// is not captured. Returns `None` when any of the following holds:
/// - the text does not start with `alter`;
/// - there is no `table` keyword;
/// - no table name follows the keyword.
fn parse_alter_table(remaining_sql_str: &str) -> Option<(String, String)> {
    const PLACEHOLDER_COLUMN: &str = "__place_holder_column_for_syntax_checking";

    let sql = remaining_sql_str.to_lowercase();
    if !sql.starts_with("alter") {
        return None;
    }

    // `find` returns byte offsets, so slice the string instead of skipping chars.
    // Skipping chars by a byte count drifts past multi-byte characters.
    let word_after = |keyword: &str| -> Option<String> {
        let start = sql.find(keyword)? + keyword.len();
        let rest = sql[start..].trim_start();
        let end = rest
            .find(|c: char| c.is_whitespace() || c == ';')
            .unwrap_or(rest.len());
        let word = &rest[..end];
        (!word.is_empty()).then(|| word.to_string())
    };

    let table_to_alter = word_after("table")?;
    let column_name = word_after("rename column")
        .or_else(|| word_after("drop column"))
        .unwrap_or_else(|| PLACEHOLDER_COLUMN.to_string());
    Some((table_to_alter, column_name))
}
257
258impl Drop for Connection {
259    fn drop(&mut self) {
260        unsafe { sqlite3_close(self.sqlite3) };
261    }
262}
263
#[cfg(test)]
mod test {
    use anyhow::Result;
    use indoc::indoc;

    use crate::connection::Connection;

    #[test]
    fn string_round_trips() -> Result<()> {
        let connection = Connection::open_memory(Some("string_round_trips"));
        connection
            .exec(indoc! {"
            CREATE TABLE text (
                text TEXT
            );"})
            .unwrap()()
        .unwrap();

        let text = "Some test text";

        connection
            .exec_bound("INSERT INTO text (text) VALUES (?);")
            .unwrap()(text)
        .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM text;").unwrap()().unwrap(),
            Some(text.to_string())
        );

        Ok(())
    }

    #[test]
    fn tuple_round_trips() {
        let connection = Connection::open_memory(Some("tuple_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    text TEXT,
                    integer INTEGER,
                    blob BLOB
                );"})
            .unwrap()()
        .unwrap();

        let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
        let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);

        let mut insert = connection
            .exec_bound::<(String, usize, Vec<u8>)>(
                "INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)",
            )
            .unwrap();

        insert(tuple1.clone()).unwrap();
        insert(tuple2.clone()).unwrap();

        assert_eq!(
            connection
                .select::<(String, usize, Vec<u8>)>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            vec![tuple1, tuple2]
        );
    }

    #[test]
    fn bool_round_trips() {
        let connection = Connection::open_memory(Some("bool_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE bools (
                    t INTEGER,
                    f INTEGER
                );"})
            .unwrap()()
        .unwrap();

        connection
            .exec_bound("INSERT INTO bools(t, f) VALUES (?, ?)")
            .unwrap()((true, false))
        .unwrap();

        assert_eq!(
            connection
                .select_row::<(bool, bool)>("SELECT * FROM bools;")
                .unwrap()()
            .unwrap(),
            Some((true, false))
        );
    }

    #[test]
    fn backup_works() {
        let connection1 = Connection::open_memory(Some("backup_works"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap()()
        .unwrap();
        let blob = vec![0, 1, 2, 4, 8, 16, 32, 64];
        connection1
            .exec_bound::<Vec<u8>>("INSERT INTO blobs (data) VALUES (?);")
            .unwrap()(blob.clone())
        .unwrap();

        // Backup connection1 to connection2
        let connection2 = Connection::open_memory(Some("backup_works_other"));
        connection1.backup_main(&connection2).unwrap();

        // Read from the destination, not the source, so the backup is actually verified
        let read_blobs = connection2
            .select::<Vec<u8>>("SELECT * FROM blobs;")
            .unwrap()()
        .unwrap();
        assert_eq!(read_blobs, vec![blob]);
    }

    #[test]
    fn multi_step_statement_works() {
        let connection = Connection::open_memory(Some("multi_step_statement_works"));

        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    col INTEGER
                )"})
            .unwrap()()
        .unwrap();

        connection
            .exec(indoc! {"
            INSERT INTO test(col) VALUES (2)"})
            .unwrap()()
        .unwrap();

        assert_eq!(
            connection
                .select_row::<usize>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            Some(2)
        );
    }

    #[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
    #[test]
    fn test_sql_has_syntax_errors() {
        let connection = Connection::open_memory(Some("test_sql_has_syntax_errors"));
        let first_stmt =
            "CREATE TABLE kv_store(key TEXT PRIMARY KEY, value TEXT NOT NULL) STRICT ;";
        let second_stmt = "SELECT FROM";

        let second_offset = connection.sql_has_syntax_error(second_stmt).unwrap().1;

        let res = connection
            .sql_has_syntax_error(&format!("{}\n{}", first_stmt, second_stmt))
            .map(|(_, offset)| offset);

        assert_eq!(res, Some(first_stmt.len() + second_offset + 1));
    }

    #[test]
    fn test_alter_table_syntax() {
        let connection = Connection::open_memory(Some("test_alter_table_syntax"));

        assert!(
            connection
                .sql_has_syntax_error("ALTER TABLE test ADD x TEXT")
                .is_none()
        );

        assert!(
            connection
                .sql_has_syntax_error("ALTER TABLE test AAD x TEXT")
                .is_some()
        );
    }
}