connection.rs

  1use std::{
  2    ffi::{CStr, CString},
  3    marker::PhantomData,
  4    path::Path,
  5};
  6
  7use anyhow::{anyhow, Result};
  8use libsqlite3_sys::*;
  9
 10pub struct Connection {
 11    pub(crate) sqlite3: *mut sqlite3,
 12    persistent: bool,
 13    pub(crate) write: bool,
 14    _sqlite: PhantomData<sqlite3>,
 15}
 16unsafe impl Send for Connection {}
 17
 18impl Connection {
 19    pub(crate) fn open(uri: &str, persistent: bool) -> Result<Self> {
 20        let mut connection = Self {
 21            sqlite3: 0 as *mut _,
 22            persistent,
 23            write: true,
 24            _sqlite: PhantomData,
 25        };
 26
 27        let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
 28        unsafe {
 29            sqlite3_open_v2(
 30                CString::new(uri)?.as_ptr(),
 31                &mut connection.sqlite3,
 32                flags,
 33                0 as *const _,
 34            );
 35
 36            // Turn on extended error codes
 37            sqlite3_extended_result_codes(connection.sqlite3, 1);
 38
 39            connection.last_error()?;
 40        }
 41
 42        Ok(connection)
 43    }
 44
 45    /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
 46    /// instead.
 47    pub fn open_file(uri: &str) -> Self {
 48        Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(Some(uri)))
 49    }
 50
 51    pub fn open_memory(uri: Option<&str>) -> Self {
 52        let in_memory_path = if let Some(uri) = uri {
 53            format!("file:{}?mode=memory&cache=shared", uri)
 54        } else {
 55            ":memory:".to_string()
 56        };
 57
 58        Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
 59    }
 60
 61    pub fn persistent(&self) -> bool {
 62        self.persistent
 63    }
 64
 65    pub fn can_write(&self) -> bool {
 66        self.write
 67    }
 68
 69    pub fn backup_main(&self, destination: &Connection) -> Result<()> {
 70        unsafe {
 71            let backup = sqlite3_backup_init(
 72                destination.sqlite3,
 73                CString::new("main")?.as_ptr(),
 74                self.sqlite3,
 75                CString::new("main")?.as_ptr(),
 76            );
 77            sqlite3_backup_step(backup, -1);
 78            sqlite3_backup_finish(backup);
 79            destination.last_error()
 80        }
 81    }
 82
 83    pub fn backup_main_to(&self, destination: impl AsRef<Path>) -> Result<()> {
 84        let destination = Self::open_file(destination.as_ref().to_string_lossy().as_ref());
 85        self.backup_main(&destination)
 86    }
 87
 88    pub(crate) fn last_error(&self) -> Result<()> {
 89        unsafe {
 90            let code = sqlite3_errcode(self.sqlite3);
 91            const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
 92            if NON_ERROR_CODES.contains(&code) {
 93                return Ok(());
 94            }
 95
 96            let message = sqlite3_errmsg(self.sqlite3);
 97            let message = if message.is_null() {
 98                None
 99            } else {
100                Some(
101                    String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
102                        .into_owned(),
103                )
104            };
105
106            Err(anyhow!(
107                "Sqlite call failed with code {} and message: {:?}",
108                code as isize,
109                message
110            ))
111        }
112    }
113}
114
115impl Drop for Connection {
116    fn drop(&mut self) {
117        unsafe { sqlite3_close(self.sqlite3) };
118    }
119}
120
#[cfg(test)]
mod test {
    use anyhow::Result;
    use indoc::indoc;

    use crate::connection::Connection;

    /// A string bound into an INSERT comes back unchanged from a SELECT.
    #[test]
    fn string_round_trips() -> Result<()> {
        let connection = Connection::open_memory(Some("string_round_trips"));
        connection
            .exec(indoc! {"
            CREATE TABLE text (
                text TEXT
            );"})
            .unwrap()()
        .unwrap();

        let text = "Some test text";

        connection
            .exec_bound("INSERT INTO text (text) VALUES (?);")
            .unwrap()(text)
        .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM text;").unwrap()().unwrap(),
            Some(text.to_string())
        );

        Ok(())
    }

    /// Tuples of (text, integer, blob) bind to and read back from multiple columns.
    #[test]
    fn tuple_round_trips() {
        let connection = Connection::open_memory(Some("tuple_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    text TEXT,
                    integer INTEGER,
                    blob BLOB
                );"})
            .unwrap()()
        .unwrap();

        let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
        let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);

        let mut insert = connection
            .exec_bound::<(String, usize, Vec<u8>)>(
                "INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)",
            )
            .unwrap();

        insert(tuple1.clone()).unwrap();
        insert(tuple2.clone()).unwrap();

        assert_eq!(
            connection
                .select::<(String, usize, Vec<u8>)>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            vec![tuple1, tuple2]
        );
    }

    /// Booleans stored in INTEGER columns read back as the same booleans.
    #[test]
    fn bool_round_trips() {
        let connection = Connection::open_memory(Some("bool_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE bools (
                    t INTEGER,
                    f INTEGER
                );"})
            .unwrap()()
        .unwrap();

        connection
            .exec_bound("INSERT INTO bools(t, f) VALUES (?, ?)")
            .unwrap()((true, false))
        .unwrap();

        assert_eq!(
            connection
                .select_row::<(bool, bool)>("SELECT * FROM bools;")
                .unwrap()()
            .unwrap(),
            Some((true, false))
        );
    }

    /// Data written through one connection is visible in the backup destination.
    #[test]
    fn backup_works() {
        let connection1 = Connection::open_memory(Some("backup_works"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap()()
        .unwrap();
        let blob = vec![0, 1, 2, 4, 8, 16, 32, 64];
        connection1
            .exec_bound::<Vec<u8>>("INSERT INTO blobs (data) VALUES (?);")
            .unwrap()(blob.clone())
        .unwrap();

        // Backup connection1 to connection2
        let connection2 = Connection::open_memory(Some("backup_works_other"));
        connection1.backup_main(&connection2).unwrap();

        // Read the blob back through connection2: querying connection1 would
        // pass even if the backup had copied nothing.
        let read_blobs = connection2
            .select::<Vec<u8>>("SELECT * FROM blobs;")
            .unwrap()()
        .unwrap();
        assert_eq!(read_blobs, vec![blob]);
    }

    /// Separate statements executed one after another all take effect.
    #[test]
    fn multi_step_statement_works() {
        let connection = Connection::open_memory(Some("multi_step_statement_works"));

        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    col INTEGER
                )"})
            .unwrap()()
        .unwrap();

        connection
            .exec(indoc! {"
            INSERT INTO test(col) VALUES (2)"})
            .unwrap()()
        .unwrap();

        assert_eq!(
            connection
                .select_row::<usize>("SELECt * FROM test")
                .unwrap()()
            .unwrap(),
            Some(2)
        );
    }
}