connection.rs

  1use std::{
  2    ffi::{CStr, CString},
  3    marker::PhantomData,
  4};
  5
  6use anyhow::{anyhow, Result};
  7use libsqlite3_sys::*;
  8
  9use crate::statement::Statement;
 10
 11pub struct Connection {
 12    pub(crate) sqlite3: *mut sqlite3,
 13    persistent: bool,
 14    phantom: PhantomData<sqlite3>,
 15}
 16unsafe impl Send for Connection {}
 17
 18impl Connection {
 19    fn open(uri: &str, persistent: bool) -> Result<Self> {
 20        let mut connection = Self {
 21            sqlite3: 0 as *mut _,
 22            persistent,
 23            phantom: PhantomData,
 24        };
 25
 26        let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
 27        unsafe {
 28            sqlite3_open_v2(
 29                CString::new(uri)?.as_ptr(),
 30                &mut connection.sqlite3,
 31                flags,
 32                0 as *const _,
 33            );
 34
 35            connection.last_error()?;
 36        }
 37
 38        Ok(connection)
 39    }
 40
 41    /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
 42    /// instead.
 43    pub fn open_file(uri: &str) -> Self {
 44        Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(uri))
 45    }
 46
 47    pub fn open_memory(uri: &str) -> Self {
 48        let in_memory_path = format!("file:{}?mode=memory&cache=shared", uri);
 49        Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
 50    }
 51
 52    pub fn persistent(&self) -> bool {
 53        self.persistent
 54    }
 55
 56    pub(crate) fn last_insert_id(&self) -> i64 {
 57        unsafe { sqlite3_last_insert_rowid(self.sqlite3) }
 58    }
 59
 60    pub fn insert(&self, query: impl AsRef<str>) -> Result<i64> {
 61        self.exec(query)?;
 62        Ok(self.last_insert_id())
 63    }
 64
 65    pub fn exec(&self, query: impl AsRef<str>) -> Result<()> {
 66        unsafe {
 67            sqlite3_exec(
 68                self.sqlite3,
 69                CString::new(query.as_ref())?.as_ptr(),
 70                None,
 71                0 as *mut _,
 72                0 as *mut _,
 73            );
 74            self.last_error()?;
 75        }
 76        Ok(())
 77    }
 78
 79    pub fn prepare<T: AsRef<str>>(&self, query: T) -> Result<Statement> {
 80        Statement::prepare(&self, query)
 81    }
 82
 83    pub fn backup_main(&self, destination: &Connection) -> Result<()> {
 84        unsafe {
 85            let backup = sqlite3_backup_init(
 86                destination.sqlite3,
 87                CString::new("main")?.as_ptr(),
 88                self.sqlite3,
 89                CString::new("main")?.as_ptr(),
 90            );
 91            sqlite3_backup_step(backup, -1);
 92            sqlite3_backup_finish(backup);
 93            destination.last_error()
 94        }
 95    }
 96
 97    pub(crate) fn last_error(&self) -> Result<()> {
 98        const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
 99        unsafe {
100            let code = sqlite3_errcode(self.sqlite3);
101            if NON_ERROR_CODES.contains(&code) {
102                return Ok(());
103            }
104
105            let message = sqlite3_errmsg(self.sqlite3);
106            let message = if message.is_null() {
107                None
108            } else {
109                Some(
110                    String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
111                        .into_owned(),
112                )
113            };
114
115            Err(anyhow!(
116                "Sqlite call failed with code {} and message: {:?}",
117                code as isize,
118                message
119            ))
120        }
121    }
122}
123
124impl Drop for Connection {
125    fn drop(&mut self) {
126        unsafe { sqlite3_close(self.sqlite3) };
127    }
128}
129
#[cfg(test)]
mod test {
    use anyhow::Result;
    use indoc::indoc;

    use crate::connection::Connection;

    #[test]
    fn string_round_trips() -> Result<()> {
        let connection = Connection::open_memory("string_round_trips");
        connection
            .exec(indoc! {"
            CREATE TABLE text (
                text TEXT
            );"})
            .unwrap();

        let text = "Some test text";

        connection
            .prepare("INSERT INTO text (text) VALUES (?);")
            .unwrap()
            .with_bindings(text)
            .unwrap()
            .exec()
            .unwrap();

        assert_eq!(
            &connection
                .prepare("SELECT text FROM text;")
                .unwrap()
                .row::<String>()
                .unwrap(),
            text
        );

        Ok(())
    }

    #[test]
    fn tuple_round_trips() {
        let connection = Connection::open_memory("tuple_round_trips");
        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    text TEXT,
                    integer INTEGER,
                    blob BLOB
                );"})
            .unwrap();

        let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
        let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);

        let mut insert = connection
            .prepare("INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)")
            .unwrap();

        insert
            .with_bindings(tuple1.clone())
            .unwrap()
            .exec()
            .unwrap();
        insert
            .with_bindings(tuple2.clone())
            .unwrap()
            .exec()
            .unwrap();

        assert_eq!(
            connection
                .prepare("SELECT * FROM test")
                .unwrap()
                .rows::<(String, usize, Vec<u8>)>()
                .unwrap(),
            vec![tuple1, tuple2]
        );
    }

    #[test]
    fn backup_works() {
        let connection1 = Connection::open_memory("backup_works");
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap();
        let blob = &[0, 1, 2, 4, 8, 16, 32, 64];
        let mut write = connection1
            .prepare("INSERT INTO blobs (data) VALUES (?);")
            .unwrap();
        write.bind_blob(1, blob).unwrap();
        write.exec().unwrap();

        // Backup connection1 to connection2
        let connection2 = Connection::open_memory("backup_works_other");
        connection1.backup_main(&connection2).unwrap();

        // Read from the *destination* so the test actually checks that the blob was copied.
        let read_blobs = connection2
            .prepare("SELECT * FROM blobs;")
            .unwrap()
            .rows::<Vec<u8>>()
            .unwrap();
        assert_eq!(read_blobs, vec![blob]);
    }

    #[test]
    fn exec_reports_sql_errors() {
        let connection = Connection::open_memory("exec_reports_sql_errors");
        assert!(connection.exec("THIS IS NOT SQL;").is_err());
        // The connection stays usable after a failed statement.
        assert!(connection.exec("CREATE TABLE ok (value INTEGER);").is_ok());
    }

    #[test]
    fn memory_connections_are_not_persistent() {
        let connection = Connection::open_memory("memory_connections_are_not_persistent");
        assert!(!connection.persistent());
    }
}