connection.rs

  1use std::{
  2    cell::RefCell,
  3    ffi::{CStr, CString},
  4    marker::PhantomData,
  5    path::Path,
  6    ptr,
  7};
  8
  9use anyhow::{anyhow, Result};
 10use libsqlite3_sys::*;
 11
 12pub struct Connection {
 13    pub(crate) sqlite3: *mut sqlite3,
 14    persistent: bool,
 15    pub(crate) write: RefCell<bool>,
 16    _sqlite: PhantomData<sqlite3>,
 17}
 18unsafe impl Send for Connection {}
 19
 20impl Connection {
 21    pub(crate) fn open(uri: &str, persistent: bool) -> Result<Self> {
 22        let mut connection = Self {
 23            sqlite3: 0 as *mut _,
 24            persistent,
 25            write: RefCell::new(true),
 26            _sqlite: PhantomData,
 27        };
 28
 29        let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
 30        unsafe {
 31            sqlite3_open_v2(
 32                CString::new(uri)?.as_ptr(),
 33                &mut connection.sqlite3,
 34                flags,
 35                0 as *const _,
 36            );
 37
 38            // Turn on extended error codes
 39            sqlite3_extended_result_codes(connection.sqlite3, 1);
 40
 41            connection.last_error()?;
 42        }
 43
 44        Ok(connection)
 45    }
 46
 47    /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
 48    /// instead.
 49    pub fn open_file(uri: &str) -> Self {
 50        Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(Some(uri)))
 51    }
 52
 53    pub fn open_memory(uri: Option<&str>) -> Self {
 54        let in_memory_path = if let Some(uri) = uri {
 55            format!("file:{}?mode=memory&cache=shared", uri)
 56        } else {
 57            ":memory:".to_string()
 58        };
 59
 60        Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
 61    }
 62
 63    pub fn persistent(&self) -> bool {
 64        self.persistent
 65    }
 66
 67    pub fn can_write(&self) -> bool {
 68        *self.write.borrow()
 69    }
 70
 71    pub fn backup_main(&self, destination: &Connection) -> Result<()> {
 72        unsafe {
 73            let backup = sqlite3_backup_init(
 74                destination.sqlite3,
 75                CString::new("main")?.as_ptr(),
 76                self.sqlite3,
 77                CString::new("main")?.as_ptr(),
 78            );
 79            sqlite3_backup_step(backup, -1);
 80            sqlite3_backup_finish(backup);
 81            destination.last_error()
 82        }
 83    }
 84
 85    pub fn backup_main_to(&self, destination: impl AsRef<Path>) -> Result<()> {
 86        let destination = Self::open_file(destination.as_ref().to_string_lossy().as_ref());
 87        self.backup_main(&destination)
 88    }
 89
 90    pub fn sql_has_syntax_error(&self, sql: &str) -> Option<(String, usize)> {
 91        let sql = CString::new(sql).unwrap();
 92        let mut remaining_sql = sql.as_c_str();
 93        let sql_start = remaining_sql.as_ptr();
 94
 95        unsafe {
 96            while {
 97                let remaining_sql_str = remaining_sql.to_str().unwrap().trim();
 98                remaining_sql_str != ";" && !remaining_sql_str.is_empty()
 99            } {
100                let mut raw_statement = 0 as *mut sqlite3_stmt;
101                let mut remaining_sql_ptr = ptr::null();
102                sqlite3_prepare_v2(
103                    self.sqlite3,
104                    remaining_sql.as_ptr(),
105                    -1,
106                    &mut raw_statement,
107                    &mut remaining_sql_ptr,
108                );
109
110                let res = sqlite3_errcode(self.sqlite3);
111                let offset = sqlite3_error_offset(self.sqlite3);
112                let message = sqlite3_errmsg(self.sqlite3);
113
114                sqlite3_finalize(raw_statement);
115
116                if res == 1 && offset >= 0 {
117                    let err_msg =
118                        String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
119                            .into_owned();
120                    let sub_statement_correction =
121                        remaining_sql.as_ptr() as usize - sql_start as usize;
122
123                    return Some((err_msg, offset as usize + sub_statement_correction));
124                }
125                remaining_sql = CStr::from_ptr(remaining_sql_ptr);
126            }
127        }
128        None
129    }
130
131    pub(crate) fn last_error(&self) -> Result<()> {
132        unsafe {
133            let code = sqlite3_errcode(self.sqlite3);
134            const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
135            if NON_ERROR_CODES.contains(&code) {
136                return Ok(());
137            }
138
139            let message = sqlite3_errmsg(self.sqlite3);
140            let message = if message.is_null() {
141                None
142            } else {
143                Some(
144                    String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
145                        .into_owned(),
146                )
147            };
148
149            Err(anyhow!(
150                "Sqlite call failed with code {} and message: {:?}",
151                code as isize,
152                message
153            ))
154        }
155    }
156
157    pub(crate) fn with_write<T>(&self, callback: impl FnOnce(&Connection) -> T) -> T {
158        *self.write.borrow_mut() = true;
159        let result = callback(self);
160        *self.write.borrow_mut() = false;
161        result
162    }
163}
164
165impl Drop for Connection {
166    fn drop(&mut self) {
167        unsafe { sqlite3_close(self.sqlite3) };
168    }
169}
170
#[cfg(test)]
mod test {
    use anyhow::Result;
    use indoc::indoc;

    use crate::connection::Connection;

    #[test]
    fn string_round_trips() -> Result<()> {
        let connection = Connection::open_memory(Some("string_round_trips"));
        connection
            .exec(indoc! {"
            CREATE TABLE text (
                text TEXT
            );"})
            .unwrap()()
        .unwrap();

        let text = "Some test text";

        connection
            .exec_bound("INSERT INTO text (text) VALUES (?);")
            .unwrap()(text)
        .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM text;").unwrap()().unwrap(),
            Some(text.to_string())
        );

        Ok(())
    }

    #[test]
    fn tuple_round_trips() {
        let connection = Connection::open_memory(Some("tuple_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    text TEXT,
                    integer INTEGER,
                    blob BLOB
                );"})
            .unwrap()()
        .unwrap();

        let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
        let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);

        let mut insert = connection
            .exec_bound::<(String, usize, Vec<u8>)>(
                "INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)",
            )
            .unwrap();

        insert(tuple1.clone()).unwrap();
        insert(tuple2.clone()).unwrap();

        assert_eq!(
            connection
                .select::<(String, usize, Vec<u8>)>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            vec![tuple1, tuple2]
        );
    }

    #[test]
    fn bool_round_trips() {
        let connection = Connection::open_memory(Some("bool_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE bools (
                    t INTEGER,
                    f INTEGER
                );"})
            .unwrap()()
        .unwrap();

        connection
            .exec_bound("INSERT INTO bools(t, f) VALUES (?, ?)")
            .unwrap()((true, false))
        .unwrap();

        assert_eq!(
            connection
                .select_row::<(bool, bool)>("SELECT * FROM bools;")
                .unwrap()()
            .unwrap(),
            Some((true, false))
        );
    }

    #[test]
    fn backup_works() {
        let connection1 = Connection::open_memory(Some("backup_works"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap()()
        .unwrap();
        let blob = vec![0, 1, 2, 4, 8, 16, 32, 64];
        connection1
            .exec_bound::<Vec<u8>>("INSERT INTO blobs (data) VALUES (?);")
            .unwrap()(blob.clone())
        .unwrap();

        // Backup connection1 to connection2
        let connection2 = Connection::open_memory(Some("backup_works_other"));
        connection1.backup_main(&connection2).unwrap();

        // Verify the blob was copied to the backup destination
        let read_blobs = connection2
            .select::<Vec<u8>>("SELECT * FROM blobs;")
            .unwrap()()
        .unwrap();
        assert_eq!(read_blobs, vec![blob]);
    }

    #[test]
    fn multi_step_statement_works() {
        let connection = Connection::open_memory(Some("multi_step_statement_works"));

        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    col INTEGER
                )"})
            .unwrap()()
        .unwrap();

        connection
            .exec(indoc! {"
            INSERT INTO test(col) VALUES (2)"})
            .unwrap()()
        .unwrap();

        assert_eq!(
            connection
                .select_row::<usize>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            Some(2)
        );
    }

    #[test]
    fn test_sql_has_syntax_errors() {
        let connection = Connection::open_memory(Some("test_sql_has_syntax_errors"));
        let first_stmt =
            "CREATE TABLE kv_store(key TEXT PRIMARY KEY, value TEXT NOT NULL) STRICT ;";
        let second_stmt = "SELECT FROM";

        let second_offset = connection.sql_has_syntax_error(second_stmt).unwrap().1;

        let res = connection
            .sql_has_syntax_error(&format!("{}\n{}", first_stmt, second_stmt))
            .map(|(_, offset)| offset);

        assert_eq!(res, Some(first_stmt.len() + second_offset + 1));
    }
}