connection.rs

  1use std::{
  2    ffi::{CStr, CString},
  3    marker::PhantomData,
  4};
  5
  6use anyhow::{anyhow, Result};
  7use libsqlite3_sys::*;
  8
  9use crate::statement::Statement;
 10
 11pub struct Connection {
 12    pub(crate) sqlite3: *mut sqlite3,
 13    persistent: bool,
 14    phantom: PhantomData<sqlite3>,
 15}
 16unsafe impl Send for Connection {}
 17
 18impl Connection {
 19    fn open(uri: &str, persistent: bool) -> Result<Self> {
 20        let mut connection = Self {
 21            sqlite3: 0 as *mut _,
 22            persistent,
 23            phantom: PhantomData,
 24        };
 25
 26        let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
 27        unsafe {
 28            sqlite3_open_v2(
 29                CString::new(uri)?.as_ptr(),
 30                &mut connection.sqlite3,
 31                flags,
 32                0 as *const _,
 33            );
 34
 35            // Turn on extended error codes
 36            sqlite3_extended_result_codes(connection.sqlite3, 1);
 37
 38            connection.last_error()?;
 39        }
 40
 41        Ok(connection)
 42    }
 43
 44    /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
 45    /// instead.
 46    pub fn open_file(uri: &str) -> Self {
 47        Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(uri))
 48    }
 49
 50    pub fn open_memory(uri: &str) -> Self {
 51        let in_memory_path = format!("file:{}?mode=memory&cache=shared", uri);
 52        Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
 53    }
 54
 55    pub fn persistent(&self) -> bool {
 56        self.persistent
 57    }
 58
 59    pub(crate) fn last_insert_id(&self) -> i64 {
 60        unsafe { sqlite3_last_insert_rowid(self.sqlite3) }
 61    }
 62
 63    pub fn insert(&self, query: impl AsRef<str>) -> Result<i64> {
 64        self.exec(query)?;
 65        Ok(self.last_insert_id())
 66    }
 67
 68    pub fn exec(&self, query: impl AsRef<str>) -> Result<()> {
 69        unsafe {
 70            sqlite3_exec(
 71                self.sqlite3,
 72                CString::new(query.as_ref())?.as_ptr(),
 73                None,
 74                0 as *mut _,
 75                0 as *mut _,
 76            );
 77            sqlite3_errcode(self.sqlite3);
 78            self.last_error()?;
 79        }
 80        Ok(())
 81    }
 82
 83    pub fn prepare<T: AsRef<str>>(&self, query: T) -> Result<Statement> {
 84        Statement::prepare(&self, query)
 85    }
 86
 87    pub fn backup_main(&self, destination: &Connection) -> Result<()> {
 88        unsafe {
 89            let backup = sqlite3_backup_init(
 90                destination.sqlite3,
 91                CString::new("main")?.as_ptr(),
 92                self.sqlite3,
 93                CString::new("main")?.as_ptr(),
 94            );
 95            sqlite3_backup_step(backup, -1);
 96            sqlite3_backup_finish(backup);
 97            destination.last_error()
 98        }
 99    }
100
101    pub(crate) fn last_error(&self) -> Result<()> {
102        unsafe { error_to_result(sqlite3_errcode(self.sqlite3)) }
103    }
104}
105
106impl Drop for Connection {
107    fn drop(&mut self) {
108        unsafe { sqlite3_close(self.sqlite3) };
109    }
110}
111
112pub(crate) fn error_to_result(code: std::os::raw::c_int) -> Result<()> {
113    const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
114    unsafe {
115        if NON_ERROR_CODES.contains(&code) {
116            return Ok(());
117        }
118
119        let message = sqlite3_errstr(code);
120        let message = if message.is_null() {
121            None
122        } else {
123            Some(
124                String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
125                    .into_owned(),
126            )
127        };
128
129        Err(anyhow!(
130            "Sqlite call failed with code {} and message: {:?}",
131            code as isize,
132            message
133        ))
134    }
135}
136
#[cfg(test)]
mod test {
    use anyhow::Result;
    use indoc::indoc;

    use crate::{connection::Connection, migrations::Migration};

    #[test]
    fn string_round_trips() -> Result<()> {
        let connection = Connection::open_memory("string_round_trips");
        connection
            .exec(indoc! {"
            CREATE TABLE text (
                text TEXT
            );"})
            .unwrap();

        let text = "Some test text";

        connection
            .prepare("INSERT INTO text (text) VALUES (?);")
            .unwrap()
            .with_bindings(text)
            .unwrap()
            .exec()
            .unwrap();

        assert_eq!(
            &connection
                .prepare("SELECT text FROM text;")
                .unwrap()
                .row::<String>()
                .unwrap(),
            text
        );

        Ok(())
    }

    #[test]
    fn tuple_round_trips() {
        let connection = Connection::open_memory("tuple_round_trips");
        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    text TEXT,
                    integer INTEGER,
                    blob BLOB
                );"})
            .unwrap();

        let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
        let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);

        let mut insert = connection
            .prepare("INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)")
            .unwrap();

        insert
            .with_bindings(tuple1.clone())
            .unwrap()
            .exec()
            .unwrap();
        insert
            .with_bindings(tuple2.clone())
            .unwrap()
            .exec()
            .unwrap();

        assert_eq!(
            connection
                .prepare("SELECT * FROM test")
                .unwrap()
                .rows::<(String, usize, Vec<u8>)>()
                .unwrap(),
            vec![tuple1, tuple2]
        );
    }

    #[test]
    fn backup_works() {
        let connection1 = Connection::open_memory("backup_works");
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap();
        let blob = &[0, 1, 2, 4, 8, 16, 32, 64];
        let mut write = connection1
            .prepare("INSERT INTO blobs (data) VALUES (?);")
            .unwrap();
        write.bind_blob(1, blob).unwrap();
        write.exec().unwrap();

        // Backup connection1 to connection2
        let connection2 = Connection::open_memory("backup_works_other");
        connection1.backup_main(&connection2).unwrap();

        // Read from the backup destination so the test actually checks that the data was copied
        let read_blobs = connection2
            .prepare("SELECT * FROM blobs;")
            .unwrap()
            .rows::<Vec<u8>>()
            .unwrap();
        assert_eq!(read_blobs, vec![blob]);
    }

    #[test]
    fn exec_reports_errors() {
        let connection = Connection::open_memory("exec_reports_errors");
        // Querying a table that does not exist must surface as an error, not be silently ignored.
        assert!(connection.exec("SELECT * FROM missing_table;").is_err());
    }

    #[test]
    fn test_kv_store() -> anyhow::Result<()> {
        let connection = Connection::open_memory("kv_store");

        Migration::new(
            "kv",
            &["CREATE TABLE kv_store(
                key TEXT PRIMARY KEY,
                value TEXT NOT NULL
            ) STRICT;"],
        )
        .run(&connection)
        .unwrap();

        let mut stmt = connection.prepare("INSERT INTO kv_store(key, value) VALUES(?, ?)")?;
        stmt.bind_text(1, "a").unwrap();
        stmt.bind_text(2, "b").unwrap();
        stmt.exec().unwrap();
        let id = connection.last_insert_id();

        let res = connection
            .prepare("SELECT key, value FROM kv_store WHERE rowid = ?")?
            .with_bindings(id)?
            .row::<(String, String)>()?;

        assert_eq!(res, ("a".to_string(), "b".to_string()));

        Ok(())
    }
}