connection.rs

  1use std::{
  2    ffi::{CStr, CString},
  3    marker::PhantomData,
  4};
  5
  6use anyhow::{anyhow, Result};
  7use libsqlite3_sys::*;
  8
  9pub struct Connection {
 10    pub(crate) sqlite3: *mut sqlite3,
 11    persistent: bool,
 12    phantom: PhantomData<sqlite3>,
 13}
 14unsafe impl Send for Connection {}
 15
 16impl Connection {
 17    fn open(uri: &str, persistent: bool) -> Result<Self> {
 18        let mut connection = Self {
 19            sqlite3: 0 as *mut _,
 20            persistent,
 21            phantom: PhantomData,
 22        };
 23
 24        let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
 25        unsafe {
 26            sqlite3_open_v2(
 27                CString::new(uri)?.as_ptr(),
 28                &mut connection.sqlite3,
 29                flags,
 30                0 as *const _,
 31            );
 32
 33            // Turn on extended error codes
 34            sqlite3_extended_result_codes(connection.sqlite3, 1);
 35
 36            connection.last_error()?;
 37        }
 38
 39        Ok(connection)
 40    }
 41
 42    /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
 43    /// instead.
 44    pub fn open_file(uri: &str) -> Self {
 45        Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(Some(uri)))
 46    }
 47
 48    pub fn open_memory(uri: Option<&str>) -> Self {
 49        let in_memory_path = if let Some(uri) = uri {
 50            format!("file:{}?mode=memory&cache=shared", uri)
 51        } else {
 52            ":memory:".to_string()
 53        };
 54
 55        Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
 56    }
 57
 58    pub fn persistent(&self) -> bool {
 59        self.persistent
 60    }
 61
 62    pub fn backup_main(&self, destination: &Connection) -> Result<()> {
 63        unsafe {
 64            let backup = sqlite3_backup_init(
 65                destination.sqlite3,
 66                CString::new("main")?.as_ptr(),
 67                self.sqlite3,
 68                CString::new("main")?.as_ptr(),
 69            );
 70            sqlite3_backup_step(backup, -1);
 71            sqlite3_backup_finish(backup);
 72            destination.last_error()
 73        }
 74    }
 75
 76    pub(crate) fn last_error(&self) -> Result<()> {
 77        unsafe {
 78            let code = sqlite3_errcode(self.sqlite3);
 79            const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
 80            if NON_ERROR_CODES.contains(&code) {
 81                return Ok(());
 82            }
 83
 84            let message = sqlite3_errmsg(self.sqlite3);
 85            let message = if message.is_null() {
 86                None
 87            } else {
 88                Some(
 89                    String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
 90                        .into_owned(),
 91                )
 92            };
 93
 94            Err(anyhow!(
 95                "Sqlite call failed with code {} and message: {:?}",
 96                code as isize,
 97                message
 98            ))
 99        }
100    }
101}
102
103impl Drop for Connection {
104    fn drop(&mut self) {
105        unsafe { sqlite3_close(self.sqlite3) };
106    }
107}
108
#[cfg(test)]
mod test {
    use anyhow::Result;
    use indoc::indoc;

    use crate::connection::Connection;

    #[test]
    fn string_round_trips() -> Result<()> {
        let connection = Connection::open_memory(Some("string_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE text (
                    text TEXT
                );"})
            .unwrap()()
        .unwrap();

        let text = "Some test text";

        connection
            .exec_bound("INSERT INTO text (text) VALUES (?);")
            .unwrap()(text)
        .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM text;").unwrap()().unwrap(),
            Some(text.to_string())
        );

        Ok(())
    }

    #[test]
    fn tuple_round_trips() {
        let connection = Connection::open_memory(Some("tuple_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    text TEXT,
                    integer INTEGER,
                    blob BLOB
                );"})
            .unwrap()()
        .unwrap();

        let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
        let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);

        let mut insert = connection
            .exec_bound::<(String, usize, Vec<u8>)>(
                "INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)",
            )
            .unwrap();

        insert(tuple1.clone()).unwrap();
        insert(tuple2.clone()).unwrap();

        assert_eq!(
            connection
                .select::<(String, usize, Vec<u8>)>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            vec![tuple1, tuple2]
        );
    }

    #[test]
    fn bool_round_trips() {
        let connection = Connection::open_memory(Some("bool_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE bools (
                    t INTEGER,
                    f INTEGER
                );"})
            .unwrap()()
        .unwrap();

        connection
            .exec_bound("INSERT INTO bools(t, f) VALUES (?, ?)")
            .unwrap()((true, false))
        .unwrap();

        assert_eq!(
            connection
                .select_row::<(bool, bool)>("SELECT * FROM bools;")
                .unwrap()()
            .unwrap(),
            Some((true, false))
        );
    }

    #[test]
    fn backup_works() {
        let connection1 = Connection::open_memory(Some("backup_works"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap()()
        .unwrap();
        let blob = vec![0, 1, 2, 4, 8, 16, 32, 64];
        connection1
            .exec_bound::<Vec<u8>>("INSERT INTO blobs (data) VALUES (?);")
            .unwrap()(blob.clone())
        .unwrap();

        // Backup connection1 to connection2
        let connection2 = Connection::open_memory(Some("backup_works_other"));
        connection1.backup_main(&connection2).unwrap();

        // The blob written through connection1 must now be readable from the backup copy.
        let read_blobs = connection2
            .select::<Vec<u8>>("SELECT * FROM blobs;")
            .unwrap()()
        .unwrap();
        assert_eq!(read_blobs, vec![blob]);
    }

    #[test]
    fn multi_step_statement_works() {
        let connection = Connection::open_memory(Some("multi_step_statement_works"));

        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    col INTEGER
                )"})
            .unwrap()()
        .unwrap();

        connection
            .exec(indoc! {"
                INSERT INTO test(col) VALUES (2)"})
            .unwrap()()
        .unwrap();

        assert_eq!(
            connection
                .select_row::<usize>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            Some(2)
        );
    }
}