crates/sqlez 🔗
@@ -1 +0,0 @@
-Subproject commit 10a78dbe535a0c270b6b4bc469fbbffe9fc8c36f
Mikayla Maki created
crates/sqlez | 1
crates/sqlez/.gitignore | 2
crates/sqlez/Cargo.lock | 150 ++++++++++
crates/sqlez/Cargo.toml | 12
crates/sqlez/src/bindable.rs | 209 ++++++++++++++
crates/sqlez/src/connection.rs | 220 +++++++++++++++
crates/sqlez/src/lib.rs | 6
crates/sqlez/src/migrations.rs | 261 ++++++++++++++++++
crates/sqlez/src/savepoint.rs | 110 +++++++
crates/sqlez/src/statement.rs | 342 ++++++++++++++++++++++++
crates/sqlez/src/thread_safe_connection.rs | 78 +++++
11 files changed, 1,390 insertions(+), 1 deletion(-)
@@ -1 +0,0 @@
-Subproject commit 10a78dbe535a0c270b6b4bc469fbbffe9fc8c36f
@@ -0,0 +1,2 @@
+debug/
+target/
@@ -0,0 +1,150 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 3
+
+[[package]]
+name = "addr2line"
+version = "0.17.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b9ecd88a8c8378ca913a680cd98f0f13ac67383d35993f86c90a70e3f137816b"
+dependencies = [
+ "gimli",
+]
+
+[[package]]
+name = "adler"
+version = "1.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
+
+[[package]]
+name = "anyhow"
+version = "1.0.66"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "216261ddc8289130e551ddcd5ce8a064710c0d064a4d2895c67151c92b5443f6"
+dependencies = [
+ "backtrace",
+]
+
+[[package]]
+name = "backtrace"
+version = "0.3.66"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cab84319d616cfb654d03394f38ab7e6f0919e181b1b57e1fd15e7fb4077d9a7"
+dependencies = [
+ "addr2line",
+ "cc",
+ "cfg-if",
+ "libc",
+ "miniz_oxide",
+ "object",
+ "rustc-demangle",
+]
+
+[[package]]
+name = "cc"
+version = "1.0.73"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11"
+
+[[package]]
+name = "cfg-if"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
+
+[[package]]
+name = "gimli"
+version = "0.26.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "22030e2c5a68ec659fde1e949a745124b48e6fa8b045b7ed5bd1fe4ccc5c4e5d"
+
+[[package]]
+name = "indoc"
+version = "1.0.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "adab1eaa3408fb7f0c777a73e7465fd5656136fc93b670eb6df3c88c2c1344e3"
+
+[[package]]
+name = "libc"
+version = "0.2.137"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fc7fcc620a3bff7cdd7a365be3376c97191aeaccc2a603e600951e452615bf89"
+
+[[package]]
+name = "libsqlite3-sys"
+version = "0.25.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "29f835d03d717946d28b1d1ed632eb6f0e24a299388ee623d0c23118d3e8a7fa"
+dependencies = [
+ "cc",
+ "pkg-config",
+ "vcpkg",
+]
+
+[[package]]
+name = "memchr"
+version = "2.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
+
+[[package]]
+name = "miniz_oxide"
+version = "0.5.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "96590ba8f175222643a85693f33d26e9c8a015f599c216509b1a6894af675d34"
+dependencies = [
+ "adler",
+]
+
+[[package]]
+name = "object"
+version = "0.29.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "21158b2c33aa6d4561f1c0a6ea283ca92bc54802a93b263e910746d679a7eb53"
+dependencies = [
+ "memchr",
+]
+
+[[package]]
+name = "once_cell"
+version = "1.15.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1"
+
+[[package]]
+name = "pkg-config"
+version = "0.3.26"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
+
+[[package]]
+name = "rustc-demangle"
+version = "0.1.21"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342"
+
+[[package]]
+name = "sqlez"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "indoc",
+ "libsqlite3-sys",
+ "thread_local",
+]
+
+[[package]]
+name = "thread_local"
+version = "1.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5516c27b78311c50bf42c071425c560ac799b11c30b31f87e3081965fe5e0180"
+dependencies = [
+ "once_cell",
+]
+
+[[package]]
+name = "vcpkg"
+version = "0.2.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
@@ -0,0 +1,12 @@
+[package]
+name = "sqlez"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+anyhow = { version = "1.0.38", features = ["backtrace"] }
+indoc = "1.0.7"
+libsqlite3-sys = { version = "0.25.2", features = ["bundled"] }
+thread_local = "1.1.4"
@@ -0,0 +1,209 @@
+use anyhow::Result;
+
+use crate::statement::{SqlType, Statement};
+
+pub trait Bind {
+ fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32>;
+}
+
+pub trait Column: Sized {
+ fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)>;
+}
+
+impl Bind for &[u8] {
+ fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+ statement.bind_blob(start_index, self)?;
+ Ok(start_index + 1)
+ }
+}
+
+impl Bind for Vec<u8> {
+ fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+ statement.bind_blob(start_index, self)?;
+ Ok(start_index + 1)
+ }
+}
+
+impl Column for Vec<u8> {
+ fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
+ let result = statement.column_blob(start_index)?;
+ Ok((Vec::from(result), start_index + 1))
+ }
+}
+
+impl Bind for f64 {
+ fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+ statement.bind_double(start_index, *self)?;
+ Ok(start_index + 1)
+ }
+}
+
+impl Column for f64 {
+ fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
+ let result = statement.column_double(start_index)?;
+ Ok((result, start_index + 1))
+ }
+}
+
+impl Bind for i32 {
+ fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+ statement.bind_int(start_index, *self)?;
+ Ok(start_index + 1)
+ }
+}
+
+impl Column for i32 {
+ fn column<'a>(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
+ let result = statement.column_int(start_index)?;
+ Ok((result, start_index + 1))
+ }
+}
+
+impl Bind for i64 {
+ fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+ statement.bind_int64(start_index, *self)?;
+ Ok(start_index + 1)
+ }
+}
+
+impl Column for i64 {
+ fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
+ let result = statement.column_int64(start_index)?;
+ Ok((result, start_index + 1))
+ }
+}
+
+impl Bind for usize {
+ fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+ (*self as i64).bind(statement, start_index)
+ }
+}
+
+impl Column for usize {
+ fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
+ let result = statement.column_int64(start_index)?;
+ Ok((result as usize, start_index + 1))
+ }
+}
+
+impl Bind for () {
+ fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+ statement.bind_null(start_index)?;
+ Ok(start_index + 1)
+ }
+}
+
+impl Bind for &str {
+ fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+ statement.bind_text(start_index, self)?;
+ Ok(start_index + 1)
+ }
+}
+
+impl Bind for String {
+ fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+ statement.bind_text(start_index, self)?;
+ Ok(start_index + 1)
+ }
+}
+
+impl Column for String {
+ fn column<'a>(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
+ let result = statement.column_text(start_index)?;
+ Ok((result.to_owned(), start_index + 1))
+ }
+}
+
+impl<T1: Bind, T2: Bind> Bind for (T1, T2) {
+ fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+ let next_index = self.0.bind(statement, start_index)?;
+ self.1.bind(statement, next_index)
+ }
+}
+
+impl<T1: Column, T2: Column> Column for (T1, T2) {
+ fn column<'a>(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
+ let (first, next_index) = T1::column(statement, start_index)?;
+ let (second, next_index) = T2::column(statement, next_index)?;
+ Ok(((first, second), next_index))
+ }
+}
+
+impl<T1: Bind, T2: Bind, T3: Bind> Bind for (T1, T2, T3) {
+ fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+ let next_index = self.0.bind(statement, start_index)?;
+ let next_index = self.1.bind(statement, next_index)?;
+ self.2.bind(statement, next_index)
+ }
+}
+
+impl<T1: Column, T2: Column, T3: Column> Column for (T1, T2, T3) {
+ fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
+ let (first, next_index) = T1::column(statement, start_index)?;
+ let (second, next_index) = T2::column(statement, next_index)?;
+ let (third, next_index) = T3::column(statement, next_index)?;
+ Ok(((first, second, third), next_index))
+ }
+}
+
+impl<T1: Bind, T2: Bind, T3: Bind, T4: Bind> Bind for (T1, T2, T3, T4) {
+ fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+ let next_index = self.0.bind(statement, start_index)?;
+ let next_index = self.1.bind(statement, next_index)?;
+ let next_index = self.2.bind(statement, next_index)?;
+ self.3.bind(statement, next_index)
+ }
+}
+
+impl<T1: Column, T2: Column, T3: Column, T4: Column> Column for (T1, T2, T3, T4) {
+ fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
+ let (first, next_index) = T1::column(statement, start_index)?;
+ let (second, next_index) = T2::column(statement, next_index)?;
+ let (third, next_index) = T3::column(statement, next_index)?;
+ let (forth, next_index) = T4::column(statement, next_index)?;
+ Ok(((first, second, third, forth), next_index))
+ }
+}
+
+impl<T: Bind> Bind for Option<T> {
+ fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+ if let Some(this) = self {
+ this.bind(statement, start_index)
+ } else {
+ statement.bind_null(start_index)?;
+ Ok(start_index + 1)
+ }
+ }
+}
+
+impl<T: Column> Column for Option<T> {
+ fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
+ if let SqlType::Null = statement.column_type(start_index)? {
+ Ok((None, start_index + 1))
+ } else {
+ T::column(statement, start_index).map(|(result, next_index)| (Some(result), next_index))
+ }
+ }
+}
+
+impl<T: Bind, const COUNT: usize> Bind for [T; COUNT] {
+ fn bind(&self, statement: &Statement, start_index: i32) -> Result<i32> {
+ let mut current_index = start_index;
+ for binding in self {
+ current_index = binding.bind(statement, current_index)?
+ }
+
+ Ok(current_index)
+ }
+}
+
+impl<T: Column + Default + Copy, const COUNT: usize> Column for [T; COUNT] {
+ fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> {
+ let mut array = [Default::default(); COUNT];
+ let mut current_index = start_index;
+ for i in 0..COUNT {
+ (array[i], current_index) = T::column(statement, current_index)?;
+ }
+ Ok((array, current_index))
+ }
+}
@@ -0,0 +1,220 @@
+use std::{
+ ffi::{CStr, CString},
+ marker::PhantomData,
+};
+
+use anyhow::{anyhow, Result};
+use libsqlite3_sys::*;
+
+use crate::statement::Statement;
+
+pub struct Connection {
+ pub(crate) sqlite3: *mut sqlite3,
+ persistent: bool,
+ phantom: PhantomData<sqlite3>,
+}
+unsafe impl Send for Connection {}
+
+impl Connection {
+ fn open(uri: &str, persistent: bool) -> Result<Self> {
+ let mut connection = Self {
+ sqlite3: 0 as *mut _,
+ persistent,
+ phantom: PhantomData,
+ };
+
+ let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
+ unsafe {
+ sqlite3_open_v2(
+ CString::new(uri)?.as_ptr(),
+ &mut connection.sqlite3,
+ flags,
+ 0 as *const _,
+ );
+
+ connection.last_error()?;
+ }
+
+ Ok(connection)
+ }
+
+ /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
+ /// instead.
+ pub fn open_file(uri: &str) -> Self {
+ Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(uri))
+ }
+
+ pub fn open_memory(uri: &str) -> Self {
+ let in_memory_path = format!("file:{}?mode=memory&cache=shared", uri);
+ Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
+ }
+
+ pub fn persistent(&self) -> bool {
+ self.persistent
+ }
+
+ pub fn exec(&self, query: impl AsRef<str>) -> Result<()> {
+ unsafe {
+ sqlite3_exec(
+ self.sqlite3,
+ CString::new(query.as_ref())?.as_ptr(),
+ None,
+ 0 as *mut _,
+ 0 as *mut _,
+ );
+ self.last_error()?;
+ }
+ Ok(())
+ }
+
+ pub fn prepare<T: AsRef<str>>(&self, query: T) -> Result<Statement> {
+ Statement::prepare(&self, query)
+ }
+
+ pub fn backup_main(&self, destination: &Connection) -> Result<()> {
+ unsafe {
+ let backup = sqlite3_backup_init(
+ destination.sqlite3,
+ CString::new("main")?.as_ptr(),
+ self.sqlite3,
+ CString::new("main")?.as_ptr(),
+ );
+ sqlite3_backup_step(backup, -1);
+ sqlite3_backup_finish(backup);
+ destination.last_error()
+ }
+ }
+
+ pub(crate) fn last_error(&self) -> Result<()> {
+ const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
+ unsafe {
+ let code = sqlite3_errcode(self.sqlite3);
+ if NON_ERROR_CODES.contains(&code) {
+ return Ok(());
+ }
+
+ let message = sqlite3_errmsg(self.sqlite3);
+ let message = if message.is_null() {
+ None
+ } else {
+ Some(
+ String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
+ .into_owned(),
+ )
+ };
+
+ Err(anyhow!(
+ "Sqlite call failed with code {} and message: {:?}",
+ code as isize,
+ message
+ ))
+ }
+ }
+}
+
+impl Drop for Connection {
+ fn drop(&mut self) {
+ unsafe { sqlite3_close(self.sqlite3) };
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use anyhow::Result;
+ use indoc::indoc;
+
+ use crate::connection::Connection;
+
+ #[test]
+ fn string_round_trips() -> Result<()> {
+ let connection = Connection::open_memory("string_round_trips");
+ connection
+ .exec(indoc! {"
+ CREATE TABLE text (
+ text TEXT
+ );"})
+ .unwrap();
+
+ let text = "Some test text";
+
+ connection
+ .prepare("INSERT INTO text (text) VALUES (?);")
+ .unwrap()
+ .bound(text)
+ .unwrap()
+ .run()
+ .unwrap();
+
+ assert_eq!(
+ &connection
+ .prepare("SELECT text FROM text;")
+ .unwrap()
+ .row::<String>()
+ .unwrap(),
+ text
+ );
+
+ Ok(())
+ }
+
+ #[test]
+ fn tuple_round_trips() {
+ let connection = Connection::open_memory("tuple_round_trips");
+ connection
+ .exec(indoc! {"
+ CREATE TABLE test (
+ text TEXT,
+ integer INTEGER,
+ blob BLOB
+ );"})
+ .unwrap();
+
+ let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
+ let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);
+
+ let mut insert = connection
+ .prepare("INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)")
+ .unwrap();
+
+ insert.bound(tuple1.clone()).unwrap().run().unwrap();
+ insert.bound(tuple2.clone()).unwrap().run().unwrap();
+
+ assert_eq!(
+ connection
+ .prepare("SELECT * FROM test")
+ .unwrap()
+ .rows::<(String, usize, Vec<u8>)>()
+ .unwrap(),
+ vec![tuple1, tuple2]
+ );
+ }
+
+ #[test]
+ fn backup_works() {
+ let connection1 = Connection::open_memory("backup_works");
+ connection1
+ .exec(indoc! {"
+ CREATE TABLE blobs (
+ data BLOB
+ );"})
+ .unwrap();
+ let blob = &[0, 1, 2, 4, 8, 16, 32, 64];
+ let mut write = connection1
+ .prepare("INSERT INTO blobs (data) VALUES (?);")
+ .unwrap();
+ write.bind_blob(1, blob).unwrap();
+ write.run().unwrap();
+
+ // Backup connection1 to connection2
+ let connection2 = Connection::open_memory("backup_works_other");
+ connection1.backup_main(&connection2).unwrap();
+
+ // Delete the added blob and verify its deleted on the other side
+ let read_blobs = connection1
+ .prepare("SELECT * FROM blobs;")
+ .unwrap()
+ .rows::<Vec<u8>>()
+ .unwrap();
+ assert_eq!(read_blobs, vec![blob]);
+ }
+}
@@ -0,0 +1,6 @@
+pub mod bindable;
+pub mod connection;
+pub mod migrations;
+pub mod savepoint;
+pub mod statement;
+pub mod thread_safe_connection;
@@ -0,0 +1,261 @@
+// Migrations are constructed by domain, and stored in a table in the connection db with domain name,
+// effected tables, actual query text, and order.
+// If a migration is run and any of the query texts don't match, the app panics on startup (maybe fallback
+// to creating a new db?)
+// Otherwise any missing migrations are run on the connection
+
+use anyhow::{anyhow, Result};
+use indoc::{formatdoc, indoc};
+
+use crate::connection::Connection;
+
+const MIGRATIONS_MIGRATION: Migration = Migration::new(
+ "migrations",
+ // The migrations migration must be infallable because it runs to completion
+ // with every call to migration run and is run unchecked.
+ &[indoc! {"
+ CREATE TABLE IF NOT EXISTS migrations (
+ domain TEXT,
+ step INTEGER,
+ migration TEXT
+ );
+ "}],
+);
+
+pub struct Migration {
+ domain: &'static str,
+ migrations: &'static [&'static str],
+}
+
+impl Migration {
+ pub const fn new(domain: &'static str, migrations: &'static [&'static str]) -> Self {
+ Self { domain, migrations }
+ }
+
+ fn run_unchecked(&self, connection: &Connection) -> Result<()> {
+ connection.exec(self.migrations.join(";\n"))
+ }
+
+ pub fn run(&self, connection: &Connection) -> Result<()> {
+ // Setup the migrations table unconditionally
+ MIGRATIONS_MIGRATION.run_unchecked(connection)?;
+
+ let completed_migrations = connection
+ .prepare(indoc! {"
+ SELECT domain, step, migration FROM migrations
+ WHERE domain = ?
+ ORDER BY step
+ "})?
+ .bound(self.domain)?
+ .rows::<(String, usize, String)>()?;
+
+ let mut store_completed_migration = connection
+ .prepare("INSERT INTO migrations (domain, step, migration) VALUES (?, ?, ?)")?;
+
+ for (index, migration) in self.migrations.iter().enumerate() {
+ if let Some((_, _, completed_migration)) = completed_migrations.get(index) {
+ if completed_migration != migration {
+ return Err(anyhow!(formatdoc! {"
+ Migration changed for {} at step {}
+
+ Stored migration:
+ {}
+
+ Proposed migration:
+ {}", self.domain, index, completed_migration, migration}));
+ } else {
+ // Migration already run. Continue
+ continue;
+ }
+ }
+
+ connection.exec(migration)?;
+ store_completed_migration
+ .bound((self.domain, index, *migration))?
+ .run()?;
+ }
+
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use indoc::indoc;
+
+ use crate::{connection::Connection, migrations::Migration};
+
+ #[test]
+ fn test_migrations_are_added_to_table() {
+ let connection = Connection::open_memory("migrations_are_added_to_table");
+
+ // Create first migration with a single step and run it
+ let mut migration = Migration::new(
+ "test",
+ &[indoc! {"
+ CREATE TABLE test1 (
+ a TEXT,
+ b TEXT
+ );"}],
+ );
+ migration.run(&connection).unwrap();
+
+ // Verify it got added to the migrations table
+ assert_eq!(
+ &connection
+ .prepare("SELECT (migration) FROM migrations")
+ .unwrap()
+ .rows::<String>()
+ .unwrap()[..],
+ migration.migrations
+ );
+
+ // Add another step to the migration and run it again
+ migration.migrations = &[
+ indoc! {"
+ CREATE TABLE test1 (
+ a TEXT,
+ b TEXT
+ );"},
+ indoc! {"
+ CREATE TABLE test2 (
+ c TEXT,
+ d TEXT
+ );"},
+ ];
+ migration.run(&connection).unwrap();
+
+ // Verify it is also added to the migrations table
+ assert_eq!(
+ &connection
+ .prepare("SELECT (migration) FROM migrations")
+ .unwrap()
+ .rows::<String>()
+ .unwrap()[..],
+ migration.migrations
+ );
+ }
+
+ #[test]
+ fn test_migration_setup_works() {
+ let connection = Connection::open_memory("migration_setup_works");
+
+ connection
+ .exec(indoc! {"CREATE TABLE IF NOT EXISTS migrations (
+ domain TEXT,
+ step INTEGER,
+ migration TEXT
+ );"})
+ .unwrap();
+
+ let mut store_completed_migration = connection
+ .prepare(indoc! {"
+ INSERT INTO migrations (domain, step, migration)
+ VALUES (?, ?, ?)"})
+ .unwrap();
+
+ let domain = "test_domain";
+ for i in 0..5 {
+ // Create a table forcing a schema change
+ connection
+ .exec(format!("CREATE TABLE table{} ( test TEXT );", i))
+ .unwrap();
+
+ store_completed_migration
+ .bound((domain, i, i.to_string()))
+ .unwrap()
+ .run()
+ .unwrap();
+ }
+ }
+
+ #[test]
+ fn migrations_dont_rerun() {
+ let connection = Connection::open_memory("migrations_dont_rerun");
+
+ // Create migration which clears a table
+ let migration = Migration::new("test", &["DELETE FROM test_table"]);
+
+ // Manually create the table for that migration with a row
+ connection
+ .exec(indoc! {"
+ CREATE TABLE test_table (
+ test_column INTEGER
+ );
+ INSERT INTO test_table (test_column) VALUES (1)"})
+ .unwrap();
+
+ assert_eq!(
+ connection
+ .prepare("SELECT * FROM test_table")
+ .unwrap()
+ .row::<usize>()
+ .unwrap(),
+ 1
+ );
+
+ // Run the migration verifying that the row got dropped
+ migration.run(&connection).unwrap();
+ assert_eq!(
+ connection
+ .prepare("SELECT * FROM test_table")
+ .unwrap()
+ .rows::<usize>()
+ .unwrap(),
+ Vec::new()
+ );
+
+ // Recreate the dropped row
+ connection
+ .exec("INSERT INTO test_table (test_column) VALUES (2)")
+ .unwrap();
+
+ // Run the same migration again and verify that the table was left unchanged
+ migration.run(&connection).unwrap();
+ assert_eq!(
+ connection
+ .prepare("SELECT * FROM test_table")
+ .unwrap()
+ .row::<usize>()
+ .unwrap(),
+ 2
+ );
+ }
+
+ #[test]
+ fn changed_migration_fails() {
+ let connection = Connection::open_memory("changed_migration_fails");
+
+ // Create a migration with two steps and run it
+ Migration::new(
+ "test migration",
+ &[
+ indoc! {"
+ CREATE TABLE test (
+ col INTEGER
+ )"},
+ indoc! {"
+ INSERT INTO test (col) VALUES (1)"},
+ ],
+ )
+ .run(&connection)
+ .unwrap();
+
+ // Create another migration with the same domain but different steps
+ let second_migration_result = Migration::new(
+ "test migration",
+ &[
+ indoc! {"
+ CREATE TABLE test (
+ color INTEGER
+ )"},
+ indoc! {"
+ INSERT INTO test (color) VALUES (1)"},
+ ],
+ )
+ .run(&connection);
+
+ // Verify new migration returns error when run
+ assert!(second_migration_result.is_err())
+ }
+}
@@ -0,0 +1,110 @@
+use anyhow::Result;
+
+use crate::connection::Connection;
+
+impl Connection {
+ // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
+ // returns Ok(None) or Err(_), the savepoint will be rolled back. Otherwise, the save
+ // point is released.
+ pub fn with_savepoint<F, R>(&mut self, name: impl AsRef<str>, f: F) -> Result<Option<R>>
+ where
+ F: FnOnce(&mut Connection) -> Result<Option<R>>,
+ {
+ let name = name.as_ref().to_owned();
+ self.exec(format!("SAVEPOINT {}", &name))?;
+ let result = f(self);
+ match result {
+ Ok(Some(_)) => {
+ self.exec(format!("RELEASE {}", name))?;
+ }
+ Ok(None) | Err(_) => {
+ self.exec(format!("ROLLBACK TO {}", name))?;
+ self.exec(format!("RELEASE {}", name))?;
+ }
+ }
+ result
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::connection::Connection;
+ use anyhow::Result;
+ use indoc::indoc;
+
+ #[test]
+ fn test_nested_savepoints() -> Result<()> {
+ let mut connection = Connection::open_memory("nested_savepoints");
+
+ connection
+ .exec(indoc! {"
+ CREATE TABLE text (
+ text TEXT,
+ idx INTEGER
+ );"})
+ .unwrap();
+
+ let save1_text = "test save1";
+ let save2_text = "test save2";
+
+ connection.with_savepoint("first", |save1| {
+ save1
+ .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
+ .bound((save1_text, 1))?
+ .run()?;
+
+ assert!(save1
+ .with_savepoint("second", |save2| -> Result<Option<()>, anyhow::Error> {
+ save2
+ .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
+ .bound((save2_text, 2))?
+ .run()?;
+
+ assert_eq!(
+ save2
+ .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
+ .rows::<String>()?,
+ vec![save1_text, save2_text],
+ );
+
+ anyhow::bail!("Failed second save point :(")
+ })
+ .err()
+ .is_some());
+
+ assert_eq!(
+ save1
+ .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
+ .rows::<String>()?,
+ vec![save1_text],
+ );
+
+ save1.with_savepoint("second", |save2| {
+ save2
+ .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
+ .bound((save2_text, 2))?
+ .run()?;
+
+ assert_eq!(
+ save2
+ .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
+ .rows::<String>()?,
+ vec![save1_text, save2_text],
+ );
+
+ Ok(Some(()))
+ })?;
+
+ assert_eq!(
+ save1
+ .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
+ .rows::<String>()?,
+ vec![save1_text, save2_text],
+ );
+
+ Ok(Some(()))
+ })?;
+
+ Ok(())
+ }
+}
@@ -0,0 +1,342 @@
+use std::ffi::{c_int, CString};
+use std::marker::PhantomData;
+use std::{slice, str};
+
+use anyhow::{anyhow, Context, Result};
+use libsqlite3_sys::*;
+
+use crate::bindable::{Bind, Column};
+use crate::connection::Connection;
+
+pub struct Statement<'a> {
+ raw_statement: *mut sqlite3_stmt,
+ connection: &'a Connection,
+ phantom: PhantomData<sqlite3_stmt>,
+}
+
+#[derive(Clone, Copy, PartialEq, Eq, Debug)]
+pub enum StepResult {
+ Row,
+ Done,
+ Misuse,
+ Other(i32),
+}
+
+#[derive(Clone, Copy, PartialEq, Eq, Debug)]
+pub enum SqlType {
+ Text,
+ Integer,
+ Blob,
+ Float,
+ Null,
+}
+
+impl<'a> Statement<'a> {
+ pub fn prepare<T: AsRef<str>>(connection: &'a Connection, query: T) -> Result<Self> {
+ let mut statement = Self {
+ raw_statement: 0 as *mut _,
+ connection,
+ phantom: PhantomData,
+ };
+
+ unsafe {
+ sqlite3_prepare_v2(
+ connection.sqlite3,
+ CString::new(query.as_ref())?.as_ptr(),
+ -1,
+ &mut statement.raw_statement,
+ 0 as *mut _,
+ );
+
+ connection.last_error().context("Prepare call failed.")?;
+ }
+
+ Ok(statement)
+ }
+
+ pub fn reset(&mut self) {
+ unsafe {
+ sqlite3_reset(self.raw_statement);
+ }
+ }
+
+ pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
+ let index = index as c_int;
+ let blob_pointer = blob.as_ptr() as *const _;
+ let len = blob.len() as c_int;
+ unsafe {
+ sqlite3_bind_blob(
+ self.raw_statement,
+ index,
+ blob_pointer,
+ len,
+ SQLITE_TRANSIENT(),
+ );
+ }
+ self.connection.last_error()
+ }
+
+ pub fn column_blob<'b>(&'b mut self, index: i32) -> Result<&'b [u8]> {
+ let index = index as c_int;
+ let pointer = unsafe { sqlite3_column_blob(self.raw_statement, index) };
+
+ self.connection.last_error()?;
+ if pointer.is_null() {
+ return Ok(&[]);
+ }
+ let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize };
+ self.connection.last_error()?;
+ unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
+ }
+
+ pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
+ let index = index as c_int;
+
+ unsafe {
+ sqlite3_bind_double(self.raw_statement, index, double);
+ }
+ self.connection.last_error()
+ }
+
+ pub fn column_double(&self, index: i32) -> Result<f64> {
+ let index = index as c_int;
+ let result = unsafe { sqlite3_column_double(self.raw_statement, index) };
+ self.connection.last_error()?;
+ Ok(result)
+ }
+
+ pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
+ let index = index as c_int;
+
+ unsafe {
+ sqlite3_bind_int(self.raw_statement, index, int);
+ }
+ self.connection.last_error()
+ }
+
+ pub fn column_int(&self, index: i32) -> Result<i32> {
+ let index = index as c_int;
+ let result = unsafe { sqlite3_column_int(self.raw_statement, index) };
+ self.connection.last_error()?;
+ Ok(result)
+ }
+
+ pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
+ let index = index as c_int;
+ unsafe {
+ sqlite3_bind_int64(self.raw_statement, index, int);
+ }
+ self.connection.last_error()
+ }
+
+ pub fn column_int64(&self, index: i32) -> Result<i64> {
+ let index = index as c_int;
+ let result = unsafe { sqlite3_column_int64(self.raw_statement, index) };
+ self.connection.last_error()?;
+ Ok(result)
+ }
+
+ pub fn bind_null(&self, index: i32) -> Result<()> {
+ let index = index as c_int;
+ unsafe {
+ sqlite3_bind_null(self.raw_statement, index);
+ }
+ self.connection.last_error()
+ }
+
+ pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
+ let index = index as c_int;
+ let text_pointer = text.as_ptr() as *const _;
+ let len = text.len() as c_int;
+ unsafe {
+ sqlite3_bind_blob(
+ self.raw_statement,
+ index,
+ text_pointer,
+ len,
+ SQLITE_TRANSIENT(),
+ );
+ }
+ self.connection.last_error()
+ }
+
+ pub fn column_text<'b>(&'b mut self, index: i32) -> Result<&'b str> {
+ let index = index as c_int;
+ let pointer = unsafe { sqlite3_column_text(self.raw_statement, index) };
+
+ self.connection.last_error()?;
+ if pointer.is_null() {
+ return Ok("");
+ }
+ let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize };
+ self.connection.last_error()?;
+
+ let slice = unsafe { slice::from_raw_parts(pointer as *const u8, len) };
+ Ok(str::from_utf8(slice)?)
+ }
+
+ pub fn bind<T: Bind>(&self, value: T) -> Result<()> {
+ value.bind(self, 1)?;
+ Ok(())
+ }
+
+ pub fn column<T: Column>(&mut self) -> Result<T> {
+ let (result, _) = T::column(self, 0)?;
+ Ok(result)
+ }
+
+ pub fn column_type(&mut self, index: i32) -> Result<SqlType> {
+ let result = unsafe { sqlite3_column_type(self.raw_statement, index) }; // SELECT <FRIEND> FROM TABLE
+ self.connection.last_error()?;
+ match result {
+ SQLITE_INTEGER => Ok(SqlType::Integer),
+ SQLITE_FLOAT => Ok(SqlType::Float),
+ SQLITE_TEXT => Ok(SqlType::Text),
+ SQLITE_BLOB => Ok(SqlType::Blob),
+ SQLITE_NULL => Ok(SqlType::Null),
+ _ => Err(anyhow!("Column type returned was incorrect ")),
+ }
+ }
+
+ pub fn bound(&mut self, bindings: impl Bind) -> Result<&mut Self> {
+ self.bind(bindings)?;
+ Ok(self)
+ }
+
+ fn step(&mut self) -> Result<StepResult> {
+ unsafe {
+ match sqlite3_step(self.raw_statement) {
+ SQLITE_ROW => Ok(StepResult::Row),
+ SQLITE_DONE => Ok(StepResult::Done),
+ SQLITE_MISUSE => Ok(StepResult::Misuse),
+ other => self
+ .connection
+ .last_error()
+ .map(|_| StepResult::Other(other)),
+ }
+ }
+ }
+
+ pub fn run(&mut self) -> Result<()> {
+ fn logic(this: &mut Statement) -> Result<()> {
+ while this.step()? == StepResult::Row {}
+ Ok(())
+ }
+ let result = logic(self);
+ self.reset();
+ result
+ }
+
+ pub fn map<R>(&mut self, callback: impl FnMut(&mut Statement) -> Result<R>) -> Result<Vec<R>> {
+ fn logic<R>(
+ this: &mut Statement,
+ mut callback: impl FnMut(&mut Statement) -> Result<R>,
+ ) -> Result<Vec<R>> {
+ let mut mapped_rows = Vec::new();
+ while this.step()? == StepResult::Row {
+ mapped_rows.push(callback(this)?);
+ }
+ Ok(mapped_rows)
+ }
+
+ let result = logic(self, callback);
+ self.reset();
+ result
+ }
+
+ pub fn rows<R: Column>(&mut self) -> Result<Vec<R>> {
+ self.map(|s| s.column::<R>())
+ }
+
+ pub fn single<R>(&mut self, callback: impl FnOnce(&mut Statement) -> Result<R>) -> Result<R> {
+ fn logic<R>(
+ this: &mut Statement,
+ callback: impl FnOnce(&mut Statement) -> Result<R>,
+ ) -> Result<R> {
+ if this.step()? != StepResult::Row {
+ return Err(anyhow!(
+ "Single(Map) called with query that returns no rows."
+ ));
+ }
+ callback(this)
+ }
+ let result = logic(self, callback);
+ self.reset();
+ result
+ }
+
+ pub fn row<R: Column>(&mut self) -> Result<R> {
+ self.single(|this| this.column::<R>())
+ }
+
+ pub fn maybe<R>(
+ &mut self,
+ callback: impl FnOnce(&mut Statement) -> Result<R>,
+ ) -> Result<Option<R>> {
+ fn logic<R>(
+ this: &mut Statement,
+ callback: impl FnOnce(&mut Statement) -> Result<R>,
+ ) -> Result<Option<R>> {
+ if this.step()? != StepResult::Row {
+ return Ok(None);
+ }
+ callback(this).map(|r| Some(r))
+ }
+ let result = logic(self, callback);
+ self.reset();
+ result
+ }
+
+ pub fn maybe_row<R: Column>(&mut self) -> Result<Option<R>> {
+ self.maybe(|this| this.column::<R>())
+ }
+}
+
+impl<'a> Drop for Statement<'a> {
+ fn drop(&mut self) {
+ unsafe {
+ sqlite3_finalize(self.raw_statement);
+ self.connection
+ .last_error()
+ .expect("sqlite3 finalize failed for statement :(");
+ };
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use indoc::indoc;
+
+ use crate::{connection::Connection, statement::StepResult};
+
+ #[test]
+ fn blob_round_trips() {
+ let connection1 = Connection::open_memory("blob_round_trips");
+ connection1
+ .exec(indoc! {"
+ CREATE TABLE blobs (
+ data BLOB
+ );"})
+ .unwrap();
+
+ let blob = &[0, 1, 2, 4, 8, 16, 32, 64];
+
+ let mut write = connection1
+ .prepare("INSERT INTO blobs (data) VALUES (?);")
+ .unwrap();
+ write.bind_blob(1, blob).unwrap();
+ assert_eq!(write.step().unwrap(), StepResult::Done);
+
+ // Read the blob from the
+ let connection2 = Connection::open_memory("blob_round_trips");
+ let mut read = connection2.prepare("SELECT * FROM blobs;").unwrap();
+ assert_eq!(read.step().unwrap(), StepResult::Row);
+ assert_eq!(read.column_blob(0).unwrap(), blob);
+ assert_eq!(read.step().unwrap(), StepResult::Done);
+
+ // Delete the added blob and verify its deleted on the other side
+ connection2.exec("DELETE FROM blobs;").unwrap();
+ let mut read = connection1.prepare("SELECT * FROM blobs;").unwrap();
+ assert_eq!(read.step().unwrap(), StepResult::Done);
+ }
+}
@@ -0,0 +1,78 @@
+use std::{ops::Deref, sync::Arc};
+
+use connection::Connection;
+use thread_local::ThreadLocal;
+
+use crate::connection;
+
+pub struct ThreadSafeConnection {
+ uri: Arc<str>,
+ persistent: bool,
+ initialize_query: Option<&'static str>,
+ connection: Arc<ThreadLocal<Connection>>,
+}
+
+impl ThreadSafeConnection {
+ pub fn new(uri: &str, persistent: bool) -> Self {
+ Self {
+ uri: Arc::from(uri),
+ persistent,
+ initialize_query: None,
+ connection: Default::default(),
+ }
+ }
+
+ /// Sets the query to run every time a connection is opened. This must
+ /// be infallible (EG only use pragma statements)
+ pub fn with_initialize_query(mut self, initialize_query: &'static str) -> Self {
+ self.initialize_query = Some(initialize_query);
+ self
+ }
+
+ /// Opens a new db connection with the initialized file path. This is internal and only
+ /// called from the deref function.
+ /// If opening fails, the connection falls back to a shared memory connection
+ fn open_file(&self) -> Connection {
+ Connection::open_file(self.uri.as_ref())
+ }
+
+ /// Opens a shared memory connection using the file path as the identifier. This unwraps
+ /// as we expect it always to succeed
+ fn open_shared_memory(&self) -> Connection {
+ Connection::open_memory(self.uri.as_ref())
+ }
+}
+
+impl Clone for ThreadSafeConnection {
+ fn clone(&self) -> Self {
+ Self {
+ uri: self.uri.clone(),
+ persistent: self.persistent,
+ initialize_query: self.initialize_query.clone(),
+ connection: self.connection.clone(),
+ }
+ }
+}
+
+impl Deref for ThreadSafeConnection {
+ type Target = Connection;
+
+ fn deref(&self) -> &Self::Target {
+ self.connection.get_or(|| {
+ let connection = if self.persistent {
+ self.open_file()
+ } else {
+ self.open_shared_memory()
+ };
+
+ if let Some(initialize_query) = self.initialize_query {
+ connection.exec(initialize_query).expect(&format!(
+ "Initialize query failed to execute: {}",
+ initialize_query
+ ));
+ }
+
+ connection
+ })
+ }
+}