connection.rs

  1use std::{
  2    ffi::{CStr, CString},
  3    marker::PhantomData,
  4    path::Path,
  5    ptr,
  6};
  7
  8use anyhow::{anyhow, Result};
  9use libsqlite3_sys::*;
 10
 11pub struct Connection {
 12    pub(crate) sqlite3: *mut sqlite3,
 13    persistent: bool,
 14    pub(crate) write: bool,
 15    _sqlite: PhantomData<sqlite3>,
 16}
 17unsafe impl Send for Connection {}
 18
 19impl Connection {
 20    pub(crate) fn open(uri: &str, persistent: bool) -> Result<Self> {
 21        let mut connection = Self {
 22            sqlite3: 0 as *mut _,
 23            persistent,
 24            write: true,
 25            _sqlite: PhantomData,
 26        };
 27
 28        let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
 29        unsafe {
 30            sqlite3_open_v2(
 31                CString::new(uri)?.as_ptr(),
 32                &mut connection.sqlite3,
 33                flags,
 34                0 as *const _,
 35            );
 36
 37            // Turn on extended error codes
 38            sqlite3_extended_result_codes(connection.sqlite3, 1);
 39
 40            connection.last_error()?;
 41        }
 42
 43        Ok(connection)
 44    }
 45
 46    /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
 47    /// instead.
 48    pub fn open_file(uri: &str) -> Self {
 49        Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(Some(uri)))
 50    }
 51
 52    pub fn open_memory(uri: Option<&str>) -> Self {
 53        let in_memory_path = if let Some(uri) = uri {
 54            format!("file:{}?mode=memory&cache=shared", uri)
 55        } else {
 56            ":memory:".to_string()
 57        };
 58
 59        Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
 60    }
 61
 62    pub fn persistent(&self) -> bool {
 63        self.persistent
 64    }
 65
 66    pub fn can_write(&self) -> bool {
 67        self.write
 68    }
 69
 70    pub fn backup_main(&self, destination: &Connection) -> Result<()> {
 71        unsafe {
 72            let backup = sqlite3_backup_init(
 73                destination.sqlite3,
 74                CString::new("main")?.as_ptr(),
 75                self.sqlite3,
 76                CString::new("main")?.as_ptr(),
 77            );
 78            sqlite3_backup_step(backup, -1);
 79            sqlite3_backup_finish(backup);
 80            destination.last_error()
 81        }
 82    }
 83
 84    pub fn backup_main_to(&self, destination: impl AsRef<Path>) -> Result<()> {
 85        let destination = Self::open_file(destination.as_ref().to_string_lossy().as_ref());
 86        self.backup_main(&destination)
 87    }
 88
 89    pub fn sql_has_syntax_error(&self, sql: &str) -> Option<(String, usize)> {
 90        let sql = CString::new(sql).unwrap();
 91        let mut remaining_sql = sql.as_c_str();
 92        let sql_start = remaining_sql.as_ptr();
 93
 94        unsafe {
 95            while {
 96                let remaining_sql_str = remaining_sql.to_str().unwrap().trim();
 97                remaining_sql_str != ";" && !remaining_sql_str.is_empty()
 98            } {
 99                let mut raw_statement = 0 as *mut sqlite3_stmt;
100                let mut remaining_sql_ptr = ptr::null();
101                sqlite3_prepare_v2(
102                    self.sqlite3,
103                    remaining_sql.as_ptr(),
104                    -1,
105                    &mut raw_statement,
106                    &mut remaining_sql_ptr,
107                );
108
109                let res = sqlite3_errcode(self.sqlite3);
110                let offset = sqlite3_error_offset(self.sqlite3);
111                let message = sqlite3_errmsg(self.sqlite3);
112
113                sqlite3_finalize(raw_statement);
114
115                if res == 1 && offset >= 0 {
116                    let err_msg =
117                        String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
118                            .into_owned();
119                    let sub_statement_correction =
120                        remaining_sql.as_ptr() as usize - sql_start as usize;
121
122                    return Some((err_msg, offset as usize + sub_statement_correction));
123                }
124                remaining_sql = CStr::from_ptr(remaining_sql_ptr);
125            }
126        }
127        None
128    }
129
130    pub(crate) fn last_error(&self) -> Result<()> {
131        unsafe {
132            let code = sqlite3_errcode(self.sqlite3);
133            const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
134            if NON_ERROR_CODES.contains(&code) {
135                return Ok(());
136            }
137
138            let message = sqlite3_errmsg(self.sqlite3);
139            let message = if message.is_null() {
140                None
141            } else {
142                Some(
143                    String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
144                        .into_owned(),
145                )
146            };
147
148            Err(anyhow!(
149                "Sqlite call failed with code {} and message: {:?}",
150                code as isize,
151                message
152            ))
153        }
154    }
155}
156
157impl Drop for Connection {
158    fn drop(&mut self) {
159        unsafe { sqlite3_close(self.sqlite3) };
160    }
161}
162
#[cfg(test)]
mod test {
    use anyhow::Result;
    use indoc::indoc;

    use crate::connection::Connection;

    #[test]
    fn string_round_trips() -> Result<()> {
        let connection = Connection::open_memory(Some("string_round_trips"));
        connection
            .exec(indoc! {"
            CREATE TABLE text (
                text TEXT
            );"})
            .unwrap()()
        .unwrap();

        let text = "Some test text";

        connection
            .exec_bound("INSERT INTO text (text) VALUES (?);")
            .unwrap()(text)
        .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM text;").unwrap()().unwrap(),
            Some(text.to_string())
        );

        Ok(())
    }

    #[test]
    fn tuple_round_trips() {
        let connection = Connection::open_memory(Some("tuple_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    text TEXT,
                    integer INTEGER,
                    blob BLOB
                );"})
            .unwrap()()
        .unwrap();

        let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
        let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);

        let mut insert = connection
            .exec_bound::<(String, usize, Vec<u8>)>(
                "INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)",
            )
            .unwrap();

        insert(tuple1.clone()).unwrap();
        insert(tuple2.clone()).unwrap();

        assert_eq!(
            connection
                .select::<(String, usize, Vec<u8>)>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            vec![tuple1, tuple2]
        );
    }

    #[test]
    fn bool_round_trips() {
        let connection = Connection::open_memory(Some("bool_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE bools (
                    t INTEGER,
                    f INTEGER
                );"})
            .unwrap()()
        .unwrap();

        connection
            .exec_bound("INSERT INTO bools(t, f) VALUES (?, ?)")
            .unwrap()((true, false))
        .unwrap();

        assert_eq!(
            connection
                .select_row::<(bool, bool)>("SELECT * FROM bools;")
                .unwrap()()
            .unwrap(),
            Some((true, false))
        );
    }

    #[test]
    fn backup_works() {
        let connection1 = Connection::open_memory(Some("backup_works"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap()()
        .unwrap();
        let blob = vec![0, 1, 2, 4, 8, 16, 32, 64];
        connection1
            .exec_bound::<Vec<u8>>("INSERT INTO blobs (data) VALUES (?);")
            .unwrap()(blob.clone())
        .unwrap();

        // Backup connection1 to connection2
        let connection2 = Connection::open_memory(Some("backup_works_other"));
        connection1.backup_main(&connection2).unwrap();

        // The blob must now be readable from the destination, not just the source
        let read_blobs = connection2
            .select::<Vec<u8>>("SELECT * FROM blobs;")
            .unwrap()()
        .unwrap();
        assert_eq!(read_blobs, vec![blob.clone()]);

        // The backup is an independent copy: deleting from the source leaves it intact
        connection1.exec("DELETE FROM blobs;").unwrap()().unwrap();
        let read_blobs = connection2
            .select::<Vec<u8>>("SELECT * FROM blobs;")
            .unwrap()()
        .unwrap();
        assert_eq!(read_blobs, vec![blob]);
    }

    #[test]
    fn multi_step_statement_works() {
        let connection = Connection::open_memory(Some("multi_step_statement_works"));

        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    col INTEGER
                )"})
            .unwrap()()
        .unwrap();

        connection
            .exec(indoc! {"
            INSERT INTO test(col) VALUES (2)"})
            .unwrap()()
        .unwrap();

        assert_eq!(
            connection
                .select_row::<usize>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            Some(2)
        );
    }

    #[test]
    fn test_sql_has_syntax_errors() {
        let connection = Connection::open_memory(Some("test_sql_has_syntax_errors"));
        let first_stmt =
            "CREATE TABLE kv_store(key TEXT PRIMARY KEY, value TEXT NOT NULL) STRICT ;";
        let second_stmt = "SELECT FROM";

        let second_offset = connection.sql_has_syntax_error(second_stmt).unwrap().1;

        let res = connection
            .sql_has_syntax_error(&format!("{}\n{}", first_stmt, second_stmt))
            .map(|(_, offset)| offset);

        assert_eq!(res, Some(first_stmt.len() + second_offset + 1));
    }

    #[test]
    fn sql_without_errors_is_accepted() {
        let connection = Connection::open_memory(Some("sql_without_errors_is_accepted"));
        assert_eq!(connection.sql_has_syntax_error("SELECT 1;\nSELECT 2;"), None);
        assert_eq!(connection.sql_has_syntax_error("  ;  "), None);
    }
}