connection.rs

  1use std::{
  2    cell::RefCell,
  3    ffi::{CStr, CString},
  4    marker::PhantomData,
  5    path::Path,
  6    ptr,
  7};
  8
  9use anyhow::{anyhow, Result};
 10use libsqlite3_sys::*;
 11
/// An owned handle to a SQLite database connection.
///
/// The underlying `sqlite3` handle is closed when the `Connection` is dropped.
pub struct Connection {
    /// Raw SQLite connection handle. May be null if SQLite could not allocate a handle on open.
    pub(crate) sqlite3: *mut sqlite3,
    /// `true` when opened through `open_file` on the requested path, `false` for in-memory
    /// databases created through `open_memory` (including `open_file`'s fallback).
    persistent: bool,
    /// Whether writes are currently allowed; read by `can_write` and toggled by `with_write`.
    pub(crate) write: RefCell<bool>,
    /// Marks that this struct logically owns the `sqlite3` object behind the raw pointer.
    _sqlite: PhantomData<sqlite3>,
}
// SAFETY: connections are opened with `SQLITE_OPEN_NOMUTEX` (multi-thread mode), in which a
// connection may be used from any thread as long as no two threads use it at the same time.
// `Connection` is not `Sync` (raw pointer + `RefCell`), so moving it between threads cannot
// introduce concurrent use. NOTE(review): assumes the linked SQLite is built thread-safe
// (SQLITE_THREADSAFE != 0) — confirm for non-bundled builds.
unsafe impl Send for Connection {}
 19
 20impl Connection {
 21    pub(crate) fn open(uri: &str, persistent: bool) -> Result<Self> {
 22        let mut connection = Self {
 23            sqlite3: ptr::null_mut(),
 24            persistent,
 25            write: RefCell::new(true),
 26            _sqlite: PhantomData,
 27        };
 28
 29        let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
 30        unsafe {
 31            sqlite3_open_v2(
 32                CString::new(uri)?.as_ptr(),
 33                &mut connection.sqlite3,
 34                flags,
 35                ptr::null(),
 36            );
 37
 38            // Turn on extended error codes
 39            sqlite3_extended_result_codes(connection.sqlite3, 1);
 40
 41            connection.last_error()?;
 42        }
 43
 44        Ok(connection)
 45    }
 46
 47    /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
 48    /// instead.
 49    pub fn open_file(uri: &str) -> Self {
 50        Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(Some(uri)))
 51    }
 52
 53    pub fn open_memory(uri: Option<&str>) -> Self {
 54        let in_memory_path = if let Some(uri) = uri {
 55            format!("file:{}?mode=memory&cache=shared", uri)
 56        } else {
 57            ":memory:".to_string()
 58        };
 59
 60        Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
 61    }
 62
 63    pub fn persistent(&self) -> bool {
 64        self.persistent
 65    }
 66
 67    pub fn can_write(&self) -> bool {
 68        *self.write.borrow()
 69    }
 70
 71    pub fn backup_main(&self, destination: &Connection) -> Result<()> {
 72        unsafe {
 73            let backup = sqlite3_backup_init(
 74                destination.sqlite3,
 75                CString::new("main")?.as_ptr(),
 76                self.sqlite3,
 77                CString::new("main")?.as_ptr(),
 78            );
 79            sqlite3_backup_step(backup, -1);
 80            sqlite3_backup_finish(backup);
 81            destination.last_error()
 82        }
 83    }
 84
 85    pub fn backup_main_to(&self, destination: impl AsRef<Path>) -> Result<()> {
 86        let destination = Self::open_file(destination.as_ref().to_string_lossy().as_ref());
 87        self.backup_main(&destination)
 88    }
 89
 90    pub fn sql_has_syntax_error(&self, sql: &str) -> Option<(String, usize)> {
 91        let sql = CString::new(sql).unwrap();
 92        let mut remaining_sql = sql.as_c_str();
 93        let sql_start = remaining_sql.as_ptr();
 94
 95        unsafe {
 96            let mut alter_table = None;
 97            while {
 98                let remaining_sql_str = remaining_sql.to_str().unwrap().trim();
 99                let any_remaining_sql = remaining_sql_str != ";" && !remaining_sql_str.is_empty();
100                if any_remaining_sql {
101                    alter_table = parse_alter_table(remaining_sql_str);
102                }
103                any_remaining_sql
104            } {
105                let mut raw_statement = ptr::null_mut::<sqlite3_stmt>();
106                let mut remaining_sql_ptr = ptr::null();
107
108                let (res, offset, message, _conn) = if let Some(table_to_alter) = alter_table {
109                    // ALTER TABLE is a weird statement. When preparing the statement the table's
110                    // existence is checked *before* syntax checking any other part of the statement.
111                    // Therefore, we need to make sure that the table has been created before calling
112                    // prepare. As we don't want to trash whatever database this is connected to, we
113                    // create a new in-memory DB to test.
114
115                    let temp_connection = Connection::open_memory(None);
116                    //This should always succeed, if it doesn't then you really should know about it
117                    temp_connection
118                        .exec(&format!(
119                        "CREATE TABLE {table_to_alter}(__place_holder_column_for_syntax_checking)"
120                    ))
121                        .unwrap()()
122                    .unwrap();
123
124                    sqlite3_prepare_v2(
125                        temp_connection.sqlite3,
126                        remaining_sql.as_ptr(),
127                        -1,
128                        &mut raw_statement,
129                        &mut remaining_sql_ptr,
130                    );
131
132                    (
133                        sqlite3_errcode(temp_connection.sqlite3),
134                        sqlite3_error_offset(temp_connection.sqlite3),
135                        sqlite3_errmsg(temp_connection.sqlite3),
136                        Some(temp_connection),
137                    )
138                } else {
139                    sqlite3_prepare_v2(
140                        self.sqlite3,
141                        remaining_sql.as_ptr(),
142                        -1,
143                        &mut raw_statement,
144                        &mut remaining_sql_ptr,
145                    );
146                    (
147                        sqlite3_errcode(self.sqlite3),
148                        sqlite3_error_offset(self.sqlite3),
149                        sqlite3_errmsg(self.sqlite3),
150                        None,
151                    )
152                };
153
154                sqlite3_finalize(raw_statement);
155
156                if res == 1 && offset >= 0 {
157                    let sub_statement_correction =
158                        remaining_sql.as_ptr() as usize - sql_start as usize;
159                    let err_msg =
160                        String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
161                            .into_owned();
162
163                    return Some((err_msg, offset as usize + sub_statement_correction));
164                }
165                remaining_sql = CStr::from_ptr(remaining_sql_ptr);
166                alter_table = None;
167            }
168        }
169        None
170    }
171
172    pub(crate) fn last_error(&self) -> Result<()> {
173        unsafe {
174            let code = sqlite3_errcode(self.sqlite3);
175            const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
176            if NON_ERROR_CODES.contains(&code) {
177                return Ok(());
178            }
179
180            let message = sqlite3_errmsg(self.sqlite3);
181            let message = if message.is_null() {
182                None
183            } else {
184                Some(
185                    String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
186                        .into_owned(),
187                )
188            };
189
190            Err(anyhow!(
191                "Sqlite call failed with code {} and message: {:?}",
192                code as isize,
193                message
194            ))
195        }
196    }
197
198    pub(crate) fn with_write<T>(&self, callback: impl FnOnce(&Connection) -> T) -> T {
199        *self.write.borrow_mut() = true;
200        let result = callback(self);
201        *self.write.borrow_mut() = false;
202        result
203    }
204}
205
/// If `remaining_sql_str` starts with `alter` (case-insensitively), returns the first
/// whitespace-delimited token after the first `table` keyword, lower-cased.
///
/// Returns `None` if the input does not start with `alter`, contains no `table`, or has
/// nothing after it. Used by `Connection::sql_has_syntax_error` to create a placeholder table
/// so SQLite reaches the syntax checks of an `ALTER TABLE` statement.
fn parse_alter_table(remaining_sql_str: &str) -> Option<String> {
    let remaining_sql_str = remaining_sql_str.to_lowercase();
    if !remaining_sql_str.starts_with("alter") {
        return None;
    }
    // `find` returns a byte offset, so slice by bytes. Skipping that many *chars* picks the wrong
    // position whenever multi-byte text (e.g. inside a comment) precedes the keyword.
    let table_offset = remaining_sql_str.find("table")?;
    let after_table = &remaining_sql_str[table_offset + "table".len()..];
    after_table.split_whitespace().next().map(str::to_owned)
}
224
225impl Drop for Connection {
226    fn drop(&mut self) {
227        unsafe { sqlite3_close(self.sqlite3) };
228    }
229}
230
#[cfg(test)]
mod test {
    use anyhow::Result;
    use indoc::indoc;

    use crate::connection::Connection;

    #[test]
    fn string_round_trips() -> Result<()> {
        let connection = Connection::open_memory(Some("string_round_trips"));
        connection
            .exec(indoc! {"
            CREATE TABLE text (
                text TEXT
            );"})
            .unwrap()()
        .unwrap();

        let text = "Some test text";

        connection
            .exec_bound("INSERT INTO text (text) VALUES (?);")
            .unwrap()(text)
        .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM text;").unwrap()().unwrap(),
            Some(text.to_string())
        );

        Ok(())
    }

    #[test]
    fn tuple_round_trips() {
        let connection = Connection::open_memory(Some("tuple_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    text TEXT,
                    integer INTEGER,
                    blob BLOB
                );"})
            .unwrap()()
        .unwrap();

        let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
        let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);

        let mut insert = connection
            .exec_bound::<(String, usize, Vec<u8>)>(
                "INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)",
            )
            .unwrap();

        insert(tuple1.clone()).unwrap();
        insert(tuple2.clone()).unwrap();

        assert_eq!(
            connection
                .select::<(String, usize, Vec<u8>)>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            vec![tuple1, tuple2]
        );
    }

    #[test]
    fn bool_round_trips() {
        let connection = Connection::open_memory(Some("bool_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE bools (
                    t INTEGER,
                    f INTEGER
                );"})
            .unwrap()()
        .unwrap();

        connection
            .exec_bound("INSERT INTO bools(t, f) VALUES (?, ?)")
            .unwrap()((true, false))
        .unwrap();

        assert_eq!(
            connection
                .select_row::<(bool, bool)>("SELECT * FROM bools;")
                .unwrap()()
            .unwrap(),
            Some((true, false))
        );
    }

    #[test]
    fn backup_works() {
        let connection1 = Connection::open_memory(Some("backup_works"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap()()
        .unwrap();
        let blob = vec![0, 1, 2, 4, 8, 16, 32, 64];
        connection1
            .exec_bound::<Vec<u8>>("INSERT INTO blobs (data) VALUES (?);")
            .unwrap()(blob.clone())
        .unwrap();

        // Backup connection1 to connection2
        let connection2 = Connection::open_memory(Some("backup_works_other"));
        connection1.backup_main(&connection2).unwrap();

        // The blob written through connection1 must now be readable through connection2;
        // reading from connection1 would not exercise the backup at all.
        let read_blobs = connection2
            .select::<Vec<u8>>("SELECT * FROM blobs;")
            .unwrap()()
        .unwrap();
        assert_eq!(read_blobs, vec![blob]);
    }

    #[test]
    fn multi_step_statement_works() {
        let connection = Connection::open_memory(Some("multi_step_statement_works"));

        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    col INTEGER
                )"})
            .unwrap()()
        .unwrap();

        connection
            .exec(indoc! {"
            INSERT INTO test(col) VALUES (2)"})
            .unwrap()()
        .unwrap();

        assert_eq!(
            connection
                .select_row::<usize>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            Some(2)
        );
    }

    #[test]
    fn test_sql_has_syntax_errors() {
        let connection = Connection::open_memory(Some("test_sql_has_syntax_errors"));
        let first_stmt =
            "CREATE TABLE kv_store(key TEXT PRIMARY KEY, value TEXT NOT NULL) STRICT ;";
        let second_stmt = "SELECT FROM";

        let second_offset = connection.sql_has_syntax_error(second_stmt).unwrap().1;

        let res = connection
            .sql_has_syntax_error(&format!("{}\n{}", first_stmt, second_stmt))
            .map(|(_, offset)| offset);

        assert_eq!(res, Some(first_stmt.len() + second_offset + 1));
    }

    #[test]
    fn test_alter_table_syntax() {
        let connection = Connection::open_memory(Some("test_alter_table_syntax"));

        assert!(connection
            .sql_has_syntax_error("ALTER TABLE test ADD x TEXT")
            .is_none());

        assert!(connection
            .sql_has_syntax_error("ALTER TABLE test AAD x TEXT")
            .is_some());
    }
}