connection.rs

  1use std::{
  2    cell::RefCell,
  3    ffi::{CStr, CString},
  4    marker::PhantomData,
  5    path::Path,
  6    ptr,
  7};
  8
  9use anyhow::Result;
 10use libsqlite3_sys::*;
 11
 12pub struct Connection {
 13    pub(crate) sqlite3: *mut sqlite3,
 14    persistent: bool,
 15    pub(crate) write: RefCell<bool>,
 16    _sqlite: PhantomData<sqlite3>,
 17}
 18unsafe impl Send for Connection {}
 19
 20impl Connection {
 21    fn open_with_flags(uri: &str, persistent: bool, flags: i32) -> Result<Self> {
 22        let mut connection = Self {
 23            sqlite3: ptr::null_mut(),
 24            persistent,
 25            write: RefCell::new(true),
 26            _sqlite: PhantomData,
 27        };
 28
 29        unsafe {
 30            sqlite3_open_v2(
 31                CString::new(uri)?.as_ptr(),
 32                &mut connection.sqlite3,
 33                flags,
 34                ptr::null(),
 35            );
 36
 37            // Turn on extended error codes
 38            sqlite3_extended_result_codes(connection.sqlite3, 1);
 39
 40            connection.last_error()?;
 41        }
 42
 43        Ok(connection)
 44    }
 45
 46    pub(crate) fn open(uri: &str, persistent: bool) -> Result<Self> {
 47        Self::open_with_flags(
 48            uri,
 49            persistent,
 50            SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE,
 51        )
 52    }
 53
 54    /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
 55    /// instead.
 56    pub fn open_file(uri: &str) -> Self {
 57        Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(Some(uri)))
 58    }
 59
 60    pub fn open_memory(uri: Option<&str>) -> Self {
 61        if let Some(uri) = uri {
 62            let in_memory_path = format!("file:{}?mode=memory&cache=shared", uri);
 63            return Self::open_with_flags(
 64                &in_memory_path,
 65                false,
 66                SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE | SQLITE_OPEN_URI,
 67            )
 68            .expect("Could not create fallback in memory db");
 69        } else {
 70            Self::open(":memory:", false).expect("Could not create fallback in memory db")
 71        }
 72    }
 73
 74    pub fn persistent(&self) -> bool {
 75        self.persistent
 76    }
 77
 78    pub fn can_write(&self) -> bool {
 79        *self.write.borrow()
 80    }
 81
 82    pub fn backup_main(&self, destination: &Connection) -> Result<()> {
 83        unsafe {
 84            let backup = sqlite3_backup_init(
 85                destination.sqlite3,
 86                CString::new("main")?.as_ptr(),
 87                self.sqlite3,
 88                CString::new("main")?.as_ptr(),
 89            );
 90            sqlite3_backup_step(backup, -1);
 91            sqlite3_backup_finish(backup);
 92            destination.last_error()
 93        }
 94    }
 95
 96    pub fn backup_main_to(&self, destination: impl AsRef<Path>) -> Result<()> {
 97        let destination = Self::open_file(destination.as_ref().to_string_lossy().as_ref());
 98        self.backup_main(&destination)
 99    }
100
101    pub fn sql_has_syntax_error(&self, sql: &str) -> Option<(String, usize)> {
102        let sql = CString::new(sql).unwrap();
103        let mut remaining_sql = sql.as_c_str();
104        let sql_start = remaining_sql.as_ptr();
105
106        let mut alter_table = None;
107        while {
108            let remaining_sql_str = remaining_sql.to_str().unwrap().trim();
109            let any_remaining_sql = remaining_sql_str != ";" && !remaining_sql_str.is_empty();
110            if any_remaining_sql {
111                alter_table = parse_alter_table(remaining_sql_str);
112            }
113            any_remaining_sql
114        } {
115            let mut raw_statement = ptr::null_mut::<sqlite3_stmt>();
116            let mut remaining_sql_ptr = ptr::null();
117
118            let (res, offset, message, _conn) = if let Some((table_to_alter, column)) = alter_table
119            {
120                // ALTER TABLE is a weird statement. When preparing the statement the table's
121                // existence is checked *before* syntax checking any other part of the statement.
122                // Therefore, we need to make sure that the table has been created before calling
123                // prepare. As we don't want to trash whatever database this is connected to, we
124                // create a new in-memory DB to test.
125
126                let temp_connection = Connection::open_memory(None);
127                //This should always succeed, if it doesn't then you really should know about it
128                temp_connection
129                    .exec(&format!("CREATE TABLE {table_to_alter}({column})"))
130                    .unwrap()()
131                .unwrap();
132
133                unsafe {
134                    sqlite3_prepare_v2(
135                        temp_connection.sqlite3,
136                        remaining_sql.as_ptr(),
137                        -1,
138                        &mut raw_statement,
139                        &mut remaining_sql_ptr,
140                    )
141                };
142
143                #[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
144                let offset = unsafe { sqlite3_error_offset(temp_connection.sqlite3) };
145
146                #[cfg(any(target_os = "linux", target_os = "freebsd"))]
147                let offset = 0;
148
149                unsafe {
150                    (
151                        sqlite3_errcode(temp_connection.sqlite3),
152                        offset,
153                        sqlite3_errmsg(temp_connection.sqlite3),
154                        Some(temp_connection),
155                    )
156                }
157            } else {
158                unsafe {
159                    sqlite3_prepare_v2(
160                        self.sqlite3,
161                        remaining_sql.as_ptr(),
162                        -1,
163                        &mut raw_statement,
164                        &mut remaining_sql_ptr,
165                    )
166                };
167
168                #[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
169                let offset = unsafe { sqlite3_error_offset(self.sqlite3) };
170
171                #[cfg(any(target_os = "linux", target_os = "freebsd"))]
172                let offset = 0;
173
174                unsafe {
175                    (
176                        sqlite3_errcode(self.sqlite3),
177                        offset,
178                        sqlite3_errmsg(self.sqlite3),
179                        None,
180                    )
181                }
182            };
183
184            unsafe { sqlite3_finalize(raw_statement) };
185
186            if res == 1 && offset >= 0 {
187                let sub_statement_correction = remaining_sql.as_ptr() as usize - sql_start as usize;
188                let err_msg = String::from_utf8_lossy(unsafe {
189                    CStr::from_ptr(message as *const _).to_bytes()
190                })
191                .into_owned();
192
193                return Some((err_msg, offset as usize + sub_statement_correction));
194            }
195            remaining_sql = unsafe { CStr::from_ptr(remaining_sql_ptr) };
196            alter_table = None;
197        }
198        None
199    }
200
201    pub(crate) fn last_error(&self) -> Result<()> {
202        unsafe {
203            let code = sqlite3_errcode(self.sqlite3);
204            const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
205            if NON_ERROR_CODES.contains(&code) {
206                return Ok(());
207            }
208
209            let message = sqlite3_errmsg(self.sqlite3);
210            let message = if message.is_null() {
211                None
212            } else {
213                Some(
214                    String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
215                        .into_owned(),
216                )
217            };
218
219            anyhow::bail!("Sqlite call failed with code {code} and message: {message:?}")
220        }
221    }
222
223    pub(crate) fn with_write<T>(&self, callback: impl FnOnce(&Connection) -> T) -> T {
224        *self.write.borrow_mut() = true;
225        let result = callback(self);
226        *self.write.borrow_mut() = false;
227        result
228    }
229}
230
/// Extracts the target of a leading `ALTER TABLE` statement for syntax checking.
///
/// Only the first statement of `remaining_sql_str` (the text before the first `;`) is
/// inspected, lowercased. Returns `Some((table, column))`, where `column` is the identifier
/// after `RENAME COLUMN` / `DROP COLUMN`, or a placeholder column name when neither names one.
/// Returns `None` if the statement does not start with `alter` or no table name follows
/// `table`.
fn parse_alter_table(remaining_sql_str: &str) -> Option<(String, String)> {
    const PLACEHOLDER_COLUMN: &str = "__place_holder_column_for_syntax_checking";

    let lowered = remaining_sql_str.to_lowercase();
    // Restrict the search to the first statement so names and keywords from later statements
    // (and the terminating `;` itself) cannot leak into the result.
    let statement = lowered.split(';').next().unwrap_or_default();
    if !statement.starts_with("alter") {
        return None;
    }

    // First whitespace-delimited word after `keyword`. `find` yields a byte offset, so the
    // text is sliced by bytes; the keywords are ASCII, so the slice point is a char boundary.
    let word_after = |keyword: &str| -> Option<String> {
        let offset = statement.find(keyword)? + keyword.len();
        statement[offset..].split_whitespace().next().map(String::from)
    };

    let table_to_alter = word_after("table")?;
    let column_name = word_after("rename column")
        .or_else(|| word_after("drop column"))
        .unwrap_or_else(|| PLACEHOLDER_COLUMN.to_string());
    Some((table_to_alter, column_name))
}
268
269impl Drop for Connection {
270    fn drop(&mut self) {
271        unsafe { sqlite3_close(self.sqlite3) };
272    }
273}
274
#[cfg(test)]
mod test {
    use anyhow::Result;
    use indoc::indoc;
    use std::{
        fs,
        sync::atomic::{AtomicUsize, Ordering},
    };

    use crate::connection::Connection;

    static NEXT_NAMED_MEMORY_DB_ID: AtomicUsize = AtomicUsize::new(0);

    /// Returns a database name unique to this process and call, so tests never share state.
    fn unique_named_memory_db(prefix: &str) -> String {
        format!(
            "{prefix}_{}_{}",
            std::process::id(),
            NEXT_NAMED_MEMORY_DB_ID.fetch_add(1, Ordering::Relaxed)
        )
    }

    /// The literal paths that would appear on disk if the in-memory URI were treated as a
    /// plain file name (main database plus its WAL and SHM side files).
    fn literal_named_memory_paths(name: &str) -> [String; 3] {
        let main = format!("file:{name}?mode=memory&cache=shared");
        [main.clone(), format!("{main}-wal"), format!("{main}-shm")]
    }

    /// Removes any stray literal backing files both when created and when dropped.
    struct NamedMemoryPathGuard {
        paths: [String; 3],
    }

    impl NamedMemoryPathGuard {
        fn new(name: &str) -> Self {
            let paths = literal_named_memory_paths(name);
            for path in &paths {
                let _ = fs::remove_file(path);
            }
            Self { paths }
        }
    }

    impl Drop for NamedMemoryPathGuard {
        fn drop(&mut self) {
            for path in &self.paths {
                let _ = fs::remove_file(path);
            }
        }
    }

    #[test]
    fn string_round_trips() -> Result<()> {
        let connection = Connection::open_memory(Some("string_round_trips"));
        connection
            .exec(indoc! {"
            CREATE TABLE text (
                text TEXT
            );"})
            .unwrap()()
        .unwrap();

        let text = "Some test text";

        connection
            .exec_bound("INSERT INTO text (text) VALUES (?);")
            .unwrap()(text)
        .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM text;").unwrap()().unwrap(),
            Some(text.to_string())
        );

        Ok(())
    }

    #[test]
    fn tuple_round_trips() {
        let connection = Connection::open_memory(Some("tuple_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    text TEXT,
                    integer INTEGER,
                    blob BLOB
                );"})
            .unwrap()()
        .unwrap();

        let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
        let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);

        let mut insert = connection
            .exec_bound::<(String, usize, Vec<u8>)>(
                "INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)",
            )
            .unwrap();

        insert(tuple1.clone()).unwrap();
        insert(tuple2.clone()).unwrap();

        assert_eq!(
            connection
                .select::<(String, usize, Vec<u8>)>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            vec![tuple1, tuple2]
        );
    }

    #[test]
    fn bool_round_trips() {
        let connection = Connection::open_memory(Some("bool_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE bools (
                    t INTEGER,
                    f INTEGER
                );"})
            .unwrap()()
        .unwrap();

        connection
            .exec_bound("INSERT INTO bools(t, f) VALUES (?, ?)")
            .unwrap()((true, false))
        .unwrap();

        assert_eq!(
            connection
                .select_row::<(bool, bool)>("SELECT * FROM bools;")
                .unwrap()()
            .unwrap(),
            Some((true, false))
        );
    }

    #[test]
    fn backup_works() {
        let connection1 = Connection::open_memory(Some("backup_works"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap()()
        .unwrap();
        let blob = vec![0, 1, 2, 4, 8, 16, 32, 64];
        connection1
            .exec_bound::<Vec<u8>>("INSERT INTO blobs (data) VALUES (?);")
            .unwrap()(blob.clone())
        .unwrap();

        // Backup connection1 to connection2
        let connection2 = Connection::open_memory(Some("backup_works_other"));
        connection1.backup_main(&connection2).unwrap();

        // Read from the destination to verify the blob was actually copied over
        let read_blobs = connection2
            .select::<Vec<u8>>("SELECT * FROM blobs;")
            .unwrap()()
        .unwrap();
        assert_eq!(read_blobs, vec![blob]);
    }

    #[test]
    fn named_memory_connections_do_not_create_literal_backing_files() {
        let name = unique_named_memory_db("named_memory_connections_do_not_create_backing_files");
        let guard = NamedMemoryPathGuard::new(&name);

        let connection1 = Connection::open_memory(Some(&name));
        connection1
            .exec(indoc! {"
                CREATE TABLE shared (
                    value INTEGER
                )"})
            .unwrap()()
        .unwrap();
        connection1
            .exec("INSERT INTO shared (value) VALUES (7)")
            .unwrap()()
        .unwrap();

        let connection2 = Connection::open_memory(Some(&name));
        assert_eq!(
            connection2
                .select_row::<i64>("SELECT value FROM shared")
                .unwrap()()
            .unwrap(),
            Some(7)
        );

        for path in &guard.paths {
            assert!(
                fs::metadata(path).is_err(),
                "named in-memory database unexpectedly created backing file {path}"
            );
        }
    }

    #[test]
    fn multi_step_statement_works() {
        let connection = Connection::open_memory(Some("multi_step_statement_works"));

        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    col INTEGER
                )"})
            .unwrap()()
        .unwrap();

        connection
            .exec(indoc! {"
            INSERT INTO test(col) VALUES (2)"})
            .unwrap()()
        .unwrap();

        assert_eq!(
            connection
                .select_row::<usize>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            Some(2)
        );
    }

    #[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
    #[test]
    fn test_sql_has_syntax_errors() {
        let connection = Connection::open_memory(Some("test_sql_has_syntax_errors"));
        let first_stmt =
            "CREATE TABLE kv_store(key TEXT PRIMARY KEY, value TEXT NOT NULL) STRICT ;";
        let second_stmt = "SELECT FROM";

        let second_offset = connection.sql_has_syntax_error(second_stmt).unwrap().1;

        let res = connection
            .sql_has_syntax_error(&format!("{}\n{}", first_stmt, second_stmt))
            .map(|(_, offset)| offset);

        assert_eq!(res, Some(first_stmt.len() + second_offset + 1));
    }

    #[test]
    fn test_alter_table_syntax() {
        let connection = Connection::open_memory(Some("test_alter_table_syntax"));

        assert!(
            connection
                .sql_has_syntax_error("ALTER TABLE test ADD x TEXT")
                .is_none()
        );

        assert!(
            connection
                .sql_has_syntax_error("ALTER TABLE test AAD x TEXT")
                .is_some()
        );
    }
}