diff --git a/crates/db/src/db.rs b/crates/db/src/db.rs
index e5740c5edb99b694ccc8f4f82be8d10711e1e2ed..857b5f273eb2d506f1e245e49798a1b05bf73ef9 100644
--- a/crates/db/src/db.rs
+++ b/crates/db/src/db.rs
@@ -10,6 +10,8 @@ use std::path::Path;
 
 use anyhow::Result;
 use indoc::indoc;
+use kvp::KVP_MIGRATION;
+use pane::PANE_MIGRATIONS;
 use sqlez::connection::Connection;
 use sqlez::thread_safe_connection::ThreadSafeConnection;
 
@@ -42,7 +44,8 @@ impl Db {
                     PRAGMA synchronous=NORMAL;
                     PRAGMA foreign_keys=TRUE;
                     PRAGMA case_sensitive_like=TRUE;
-                "}),
+                "})
+                .with_migrations(&[KVP_MIGRATION, WORKSPACES_MIGRATION, PANE_MIGRATIONS]),
         )
     }
 
@@ -64,7 +67,7 @@ impl Db {
     pub fn write_to<P: AsRef<Path>>(&self, dest: P) -> Result<()> {
         let destination = Connection::open_file(dest.as_ref().to_string_lossy().as_ref());
-        self.backup(&destination)
+        self.backup_main(&destination)
     }
 }
diff --git a/crates/db/src/kvp.rs b/crates/db/src/kvp.rs
index 96f13d8040bf6e289711b46462ccf88d1eafc735..6db99831f765d03a0faa9cc43ec951cf0450c7bb 100644
--- a/crates/db/src/kvp.rs
+++ b/crates/db/src/kvp.rs
@@ -1,55 +1,38 @@
-use anyhow::Result;
-use rusqlite::OptionalExtension;
-
 use super::Db;
-
-pub(crate) const KVP_M_1: &str = "
-CREATE TABLE kv_store(
-    key TEXT PRIMARY KEY,
-    value TEXT NOT NULL
-) STRICT;
-";
+use anyhow::Result;
+use indoc::indoc;
+use sqlez::migrations::Migration;
+
+pub(crate) const KVP_MIGRATION: Migration = Migration::new(
+    "kvp",
+    &[indoc! {"
+    CREATE TABLE kv_store(
+        key TEXT PRIMARY KEY,
+        value TEXT NOT NULL
+    ) STRICT;
+    "}],
+);
 
 impl Db {
     pub fn read_kvp(&self, key: &str) -> Result<Option<String>> {
-        self.real()
-            .map(|db| {
-                let lock = db.connection.lock();
-                let mut stmt = lock.prepare_cached("SELECT value FROM kv_store WHERE key = (?)")?;
-
-                Ok(stmt.query_row([key], |row| row.get(0)).optional()?)
-            })
-            .unwrap_or(Ok(None))
+        self.0
+            .prepare("SELECT value FROM kv_store WHERE key = (?)")?
+            .bind(key)?
+            .maybe_row()
     }
 
     pub fn write_kvp(&self, key: &str, value: &str) -> Result<()> {
-        self.real()
-            .map(|db| {
-                let lock = db.connection.lock();
-
-                let mut stmt = lock.prepare_cached(
-                    "INSERT OR REPLACE INTO kv_store(key, value) VALUES ((?), (?))",
-                )?;
-
-                stmt.execute([key, value])?;
-
-                Ok(())
-            })
-            .unwrap_or(Ok(()))
+        self.0
+            .prepare("INSERT OR REPLACE INTO kv_store(key, value) VALUES (?, ?)")?
+            .bind((key, value))?
+            .exec()
    }
 
    pub fn delete_kvp(&self, key: &str) -> Result<()> {
-        self.real()
-            .map(|db| {
-                let lock = db.connection.lock();
-
-                let mut stmt = lock.prepare_cached("DELETE FROM kv_store WHERE key = (?)")?;
-
-                stmt.execute([key])?;
-
-                Ok(())
-            })
-            .unwrap_or(Ok(()))
+        self.0
+            .prepare("DELETE FROM kv_store WHERE key = (?)")?
+            .bind(key)?
+            .exec()
    }
 }
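The kvp rewrite above is the clearest before/after view of this migration: rusqlite's `prepare_cached`/`OptionalExtension` plumbing collapses into sqlez's `prepare` → `bind` → `exec`/`maybe_row` chain on the wrapped connection (`self.0`). A minimal sketch of that chain against a bare sqlez `Connection` (`open_memory` is borrowed from the sqlez tests further down; the `&str` and tuple `Bind` impls are assumptions inferred from this patch's own usage):

```rust
use sqlez::connection::Connection;

fn kvp_roundtrip() -> anyhow::Result<()> {
    let connection = Connection::open_memory("kvp_example");
    connection.exec("CREATE TABLE kv_store(key TEXT PRIMARY KEY, value TEXT NOT NULL) STRICT;")?;

    // Tuples bind positionally, left to right.
    connection
        .prepare("INSERT OR REPLACE INTO kv_store(key, value) VALUES (?, ?)")?
        .bind(("dock_anchor", "bottom"))?
        .exec()?;

    // maybe_row returns Ok(None) for a missing key instead of an error.
    let value: Option<String> = connection
        .prepare("SELECT value FROM kv_store WHERE key = (?)")?
        .bind("dock_anchor")?
        .maybe_row()?;
    assert_eq!(value.as_deref(), Some("bottom"));
    Ok(())
}
```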
diff --git a/crates/db/src/migrations.rs b/crates/db/src/migrations.rs
index 8caa528fc1ef607405994338265b1460dc34f5de..a95654f420fa418d4c82e3703cf1328e000f5e20 100644
--- a/crates/db/src/migrations.rs
+++ b/crates/db/src/migrations.rs
@@ -1,16 +1,14 @@
-use rusqlite_migration::{Migrations, M};
+// // use crate::items::ITEMS_M_1;
+// use crate::{kvp::KVP_M_1, pane::PANE_M_1, WORKSPACES_MIGRATION};
 
-// use crate::items::ITEMS_M_1;
-use crate::{kvp::KVP_M_1, pane::PANE_M_1, WORKSPACE_M_1};
-
-// This must be ordered by development time! Only ever add new migrations to the end!!
-// Bad things will probably happen if you don't monotonically edit this vec!!!!
-// And no re-ordering ever!!!!!!!!!! The results of these migrations are on the user's
-// file system and so everything we do here is locked in _f_o_r_e_v_e_r_.
-lazy_static::lazy_static! {
-    pub static ref MIGRATIONS: Migrations<'static> = Migrations::new(vec![
-        M::up(KVP_M_1),
-        M::up(WORKSPACE_M_1),
-        M::up(PANE_M_1)
-    ]);
-}
+// // This must be ordered by development time! Only ever add new migrations to the end!!
+// // Bad things will probably happen if you don't monotonically edit this vec!!!!
+// // And no re-ordering ever!!!!!!!!!! The results of these migrations are on the user's
+// // file system and so everything we do here is locked in _f_o_r_e_v_e_r_.
+// lazy_static::lazy_static! {
+//     pub static ref MIGRATIONS: Migrations<'static> = Migrations::new(vec![
+//         M::up(KVP_M_1),
+//         M::up(WORKSPACE_M_1),
+//         M::up(PANE_M_1)
+//     ]);
+// }
diff --git a/crates/db/src/pane.rs b/crates/db/src/pane.rs
index 0a1812c60cc68a38c2e4238cadb620a923b7f28a..0716d19b1d209a52754159bf5bcc461a9936ed75 100644
--- a/crates/db/src/pane.rs
+++ b/crates/db/src/pane.rs
@@ -1,15 +1,14 @@
 use gpui::Axis;
+use indoc::indoc;
+use sqlez::migrations::Migration;
 
-use rusqlite::{OptionalExtension, Connection};
-use serde::{Deserialize, Serialize};
-use serde_rusqlite::{from_row, to_params_named};
-
 use crate::{items::ItemId, workspace::WorkspaceId};
 
 use super::Db;
 
-pub(crate) const PANE_M_1: &str = "
+pub(crate) const PANE_MIGRATIONS: Migration = Migration::new("pane", &[indoc! {"
 CREATE TABLE dock_panes(
     dock_pane_id INTEGER PRIMARY KEY,
     workspace_id INTEGER NOT NULL,
@@ -64,7 +63,7 @@ CREATE TABLE dock_items(
     FOREIGN KEY(dock_pane_id) REFERENCES dock_panes(dock_pane_id) ON DELETE CASCADE,
     FOREIGN KEY(item_id) REFERENCES items(item_id)ON DELETE CASCADE
 ) STRICT;
-";
+"}]);
 
 // We have an many-branched, unbalanced tree with three types:
 // Pane Groups
@@ -140,7 +139,7 @@ pub struct SerializedPane {
 
 //********* CURRENTLY IN USE TYPES: *********
 
-#[derive(Default, Debug, PartialEq, Eq, Deserialize, Serialize)]
+#[derive(Default, Debug, PartialEq, Eq)]
 pub enum DockAnchor {
     #[default]
     Bottom,
@@ -148,7 +147,7 @@ pub enum DockAnchor {
     Expanded,
 }
 
-#[derive(Default, Debug, PartialEq, Eq, Deserialize, Serialize)]
+#[derive(Default, Debug, PartialEq, Eq)]
 pub struct SerializedDockPane {
     pub anchor_position: DockAnchor,
     pub visible: bool,
@@ -160,7 +159,7 @@ impl SerializedDockPane {
     }
 }
 
-#[derive(Default, Debug, PartialEq, Eq, Deserialize, Serialize)]
+#[derive(Default, Debug, PartialEq, Eq)]
 pub(crate) struct DockRow {
     workspace_id: WorkspaceId,
     anchor_position: DockAnchor,
@@ -298,12 +297,11 @@ mod tests {
         let workspace = db.workspace_for_roots(&["/tmp"]);
 
         let dock_pane = SerializedDockPane {
-            workspace_id: workspace.workspace_id,
             anchor_position: DockAnchor::Expanded,
             visible: true,
         };
 
-        db.save_dock_pane(&dock_pane);
+        db.save_dock_pane(workspace.workspace_id, dock_pane);
 
         let new_workspace = db.workspace_for_roots(&["/tmp"]);
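`KVP_MIGRATION` and `PANE_MIGRATIONS` replace the single global `MIGRATIONS` vec: each domain now owns its own step array, and `Migration::run` (in `crates/sqlez/src/migrations.rs` below) records which steps have already been applied per domain. The append-only warning from the deleted comment block still applies within a domain. A sketch of how a second step would be added later, on a hypothetical `example` domain:

```rust
use indoc::indoc;
use sqlez::migrations::Migration;

// The domain name and already-shipped steps must never change;
// new steps are only ever appended to the end of the array.
pub(crate) const EXAMPLE_MIGRATION: Migration = Migration::new(
    "example",
    &[
        indoc! {"
        CREATE TABLE example(
            id INTEGER PRIMARY KEY
        ) STRICT;
        "},
        // Step 2, appended later without touching step 1:
        indoc! {"
        ALTER TABLE example ADD COLUMN note TEXT;
        "},
    ],
);
```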
diff --git a/crates/db/src/workspace.rs b/crates/db/src/workspace.rs
index 5237caa23c2b0bf243bafc0ef67a890f47ab0598..16ff0e78c050b453ccfe69ab426d5df7931ff754 100644
--- a/crates/db/src/workspace.rs
+++ b/crates/db/src/workspace.rs
@@ -6,12 +6,12 @@ use std::{
     os::unix::prelude::OsStrExt,
     path::{Path, PathBuf},
     sync::Arc,
-    time::{SystemTime, UNIX_EPOCH},
 };
 
-use anyhow::Result;
 use indoc::indoc;
-use sqlez::{connection::Connection, migrations::Migration};
+use sqlez::{
+    connection::Connection, migrations::Migration,
+};
 
 use crate::pane::SerializedDockPane;
 
@@ -20,8 +20,8 @@ use super::Db;
 
 // If you need to debug the worktree root code, change 'BLOB' here to 'TEXT' for easier debugging
 // you might want to update some of the parsing code as well, I've left the variations in but commented
 // out. This will panic if run on an existing db that has already been migrated
-const WORKSPACES_MIGRATION: Migration = Migration::new(
-    "migrations",
+pub(crate) const WORKSPACES_MIGRATION: Migration = Migration::new(
+    "workspace",
     &[indoc! {"
     CREATE TABLE workspaces(
         workspace_id INTEGER PRIMARY KEY,
@@ -53,8 +53,8 @@ pub struct SerializedWorkspace {
 }
 
 impl Db {
-    /// Finds or creates a workspace id for the given set of worktree roots. If the passed worktree roots is empty, return the
-    /// the last workspace id
+    /// Finds or creates a workspace id for the given set of worktree roots. If the passed worktree roots is empty,
+    /// returns the last workspace which was updated
     pub fn workspace_for_roots<P>(&self, worktree_roots: &[P]) -> SerializedWorkspace
     where
         P: AsRef<Path> + Debug,
@@ -80,23 +80,21 @@ impl Db {
     where
         P: AsRef<Path> + Debug,
     {
-        let result = (|| {
-            let tx = self.transaction()?;
-            tx.execute("INSERT INTO workspaces(last_opened_timestamp) VALUES (?)", [current_millis()?])?;
-
-            let id = WorkspaceId(tx.last_insert_rowid());
+        let res = self.with_savepoint("make_new_workspace", |conn| {
+            let workspace_id = WorkspaceId(
+                conn.prepare("INSERT INTO workspaces DEFAULT VALUES")?
+                    .insert()?,
+            );
 
-            update_worktree_roots(&tx, &id, worktree_roots)?;
-
-            tx.commit()?;
+            update_worktree_roots(conn, &workspace_id, worktree_roots)?;
 
             Ok(SerializedWorkspace {
-                workspace_id: id,
+                workspace_id,
                 dock_pane: None,
             })
-        })();
+        });
 
-        match result {
+        match res {
             Ok(serialized_workspace) => serialized_workspace,
             Err(err) => {
                 log::error!("Failed to insert new workspace into DB: {}", err);
@@ -109,19 +107,13 @@ impl Db {
     where
         P: AsRef<Path> + Debug,
     {
-        self.real()
-            .map(|db| {
-                let lock = db.connection.lock();
-
-                match get_workspace_id(worktree_roots, &lock) {
-                    Ok(workspace_id) => workspace_id,
-                    Err(err) => {
-                        log::error!("Failed to get workspace_id: {}", err);
-                        None
-                    }
-                }
-            })
-            .unwrap_or(None)
+        match get_workspace_id(worktree_roots, &self) {
+            Ok(workspace_id) => workspace_id,
+            Err(err) => {
+                log::error!("Failed to get workspace_id: {}", err);
+                None
+            }
+        }
     }
 
     // fn get_workspace_row(&self, workspace_id: WorkspaceId) -> WorkspaceRow {
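`make_new_workspace` trades the hand-rolled transaction for `with_savepoint` (added in `crates/sqlez/src/savepoint.rs` below) plus the new `Statement::insert`, which steps the statement and hands back `sqlite3_last_insert_rowid`. The same pattern in isolation, against a throwaway table:

```rust
use sqlez::connection::Connection;

// Returns the fresh PRIMARY KEY; if anything inside the closure errors,
// everything since the SAVEPOINT is rolled back before the error surfaces.
fn make_new_row() -> anyhow::Result<i64> {
    let mut connection = Connection::open_memory("savepoint_insert_example");
    connection.exec("CREATE TABLE workspaces(workspace_id INTEGER PRIMARY KEY)")?;

    connection.with_savepoint("make_new_workspace", |conn| {
        conn.prepare("INSERT INTO workspaces DEFAULT VALUES")?.insert()
    })
}
```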
@@ -135,123 +127,73 @@ impl Db {
     where
         P: AsRef<Path> + Debug,
     {
-        fn logic<P>(
-            connection: &mut Connection,
-            workspace_id: &WorkspaceId,
-            worktree_roots: &[P],
-        ) -> Result<()>
-        where
-            P: AsRef<Path> + Debug,
-        {
-            let tx = connection.transaction()?;
-            update_worktree_roots(&tx, workspace_id, worktree_roots)?;
-            tx.commit()?;
-            Ok(())
+        match self.with_savepoint("update_worktrees", |conn| {
+            update_worktree_roots(conn, workspace_id, worktree_roots)
+        }) {
+            Ok(_) => {}
+            Err(err) => log::error!(
+                "Failed to update workspace {:?} with roots {:?}, error: {}",
+                workspace_id,
+                worktree_roots,
+                err
+            ),
         }
-
-        self.real().map(|db| {
-            let mut lock = db.connection.lock();
-
-            match logic(&mut lock, workspace_id, worktree_roots) {
-                Ok(_) => {}
-                Err(err) => {
-                    log::error!(
-                        "Failed to update the worktree roots for {:?}, roots: {:?}, error: {}",
-                        workspace_id,
-                        worktree_roots,
-                        err
-                    );
-                }
-            }
-        });
     }
 
     fn last_workspace_id(&self) -> Option<WorkspaceId> {
-        fn logic(connection: &mut Connection) -> Result<Option<WorkspaceId>> {
-            let mut stmt = connection.prepare(
+        let res = self
+            .prepare(
                 "SELECT workspace_id FROM workspaces ORDER BY last_opened_timestamp DESC LIMIT 1",
-            )?;
+            )
+            .and_then(|stmt| stmt.maybe_row())
+            .map(|row| row.map(|id| WorkspaceId(id)));
 
-            Ok(stmt
-                .query_row([], |row| Ok(WorkspaceId(row.get(0)?)))
-                .optional()?)
+        match res {
+            Ok(result) => result,
+            Err(err) => {
+                log::error!("Failed to get last workspace id, err: {}", err);
+                return None;
+            }
         }
-
-        self.real()
-            .map(|db| {
-                let mut lock = db.connection.lock();
-
-                match logic(&mut lock) {
-                    Ok(result) => result,
-                    Err(err) => {
-                        log::error!("Failed to get last workspace id, err: {}", err);
-                        None
-                    }
-                }
-            })
-            .unwrap_or(None)
     }
 
     /// Returns the previous workspace ids sorted by last modified along with their opened worktree roots
     pub fn recent_workspaces(&self, limit: usize) -> Vec<(WorkspaceId, Vec<Arc<Path>>)> {
-        fn logic(
-            connection: &mut Connection,
-            limit: usize,
-        ) -> Result<Vec<(WorkspaceId, Vec<Arc<Path>>)>, anyhow::Error> {
-            let tx = connection.transaction()?;
-            let result = {
-                let mut stmt = tx.prepare(
-                    "SELECT workspace_id FROM workspaces ORDER BY last_opened_timestamp DESC LIMIT ?",
-                )?;
-
-                let workspace_ids = stmt
-                    .query_map([limit], |row| Ok(WorkspaceId(row.get(0)?)))?
-                    .collect::<Result<Vec<_>, rusqlite::Error>>()?;
-
-                let mut result = Vec::new();
-                let mut stmt =
-                    tx.prepare("SELECT worktree_root FROM worktree_roots WHERE workspace_id = ?")?;
-                for workspace_id in workspace_ids {
-                    let roots = stmt
-                        .query_map([workspace_id.0], |row| {
-                            let row = row.get::<_, Vec<u8>>(0)?;
-                            Ok(PathBuf::from(OsStr::from_bytes(&row)).into())
-                            // If you need to debug this, here's the string parsing:
-                            // let row = row.get::<_, String>(0)?;
-                            // Ok(PathBuf::from(row).into())
-                        })?
-                        .collect::<Result<Vec<_>, rusqlite::Error>>()?;
-                    result.push((workspace_id, roots))
-                }
-
-                result
-            };
-            tx.commit()?;
-            return Ok(result);
-        }
-
-        self.real()
-            .map(|db| {
-                let mut lock = db.connection.lock();
-
-                match logic(&mut lock, limit) {
-                    Ok(result) => result,
-                    Err(err) => {
-                        log::error!("Failed to get recent workspaces, err: {}", err);
-                        Vec::new()
-                    }
-                }
-            })
-            .unwrap_or_else(|| Vec::new())
+        let res = self.with_savepoint("recent_workspaces", |conn| {
+            let ids = conn.prepare("SELECT workspace_id FROM workspaces ORDER BY last_opened_timestamp DESC LIMIT ?")?
+                .bind(limit)?
+                .rows::<i64>()?
+                .iter()
+                .map(|row| WorkspaceId(*row));
+
+            let mut result = Vec::new();
+
+            let mut stmt = conn.prepare("SELECT worktree_root FROM worktree_roots WHERE workspace_id = ?")?;
+            for workspace_id in ids {
+                let roots = stmt.bind(workspace_id.0)?
+                    .rows::<Vec<u8>>()?
+                    .iter()
+                    .map(|row| {
+                        PathBuf::from(OsStr::from_bytes(&row)).into()
+                    })
+                    .collect();
+                result.push((workspace_id, roots))
+            }
+
+            Ok(result)
+        });
+
+        match res {
+            Ok(result) => result,
+            Err(err) => {
+                log::error!("Failed to get recent workspaces, err: {}", err);
+                Vec::new()
+            }
+        }
     }
 }
 
-fn current_millis() -> Result<u64> {
-    // SQLite only supports u64 integers, which means this code will trigger
-    // undefined behavior in 584 million years. It's probably fine.
-    Ok(SystemTime::now().duration_since(UNIX_EPOCH)?.as_millis() as u64)
-}
-
 fn update_worktree_roots<P>(
     connection: &Connection,
     workspace_id: &WorkspaceId,
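`recent_workspaces` now reads typed rows straight out of sqlez, `rows::<i64>()` for the ids and `rows::<Vec<u8>>()` for the blob-encoded roots, instead of rusqlite's `query_map` closures. A compact sketch of that typed read path (the integer `Bind`/`Column` impls are assumptions inferred from the patch's own `.bind(limit)` and `WorkspaceId(*row)` usage):

```rust
use sqlez::connection::Connection;

fn newest_ids(limit: usize) -> anyhow::Result<Vec<i64>> {
    let connection = Connection::open_memory("rows_example");
    connection.exec("CREATE TABLE workspaces(workspace_id INTEGER PRIMARY KEY)")?;
    for _ in 0..3 {
        connection.prepare("INSERT INTO workspaces DEFAULT VALUES")?.exec()?;
    }

    // rows::<T>() drains every result row into a Vec<T>.
    connection
        .prepare("SELECT workspace_id FROM workspaces ORDER BY workspace_id DESC LIMIT ?")?
        .bind(limit)?
        .rows::<i64>()
}
```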
@@ -265,33 +207,32 @@ where
     if let Some(preexisting_id) = preexisting_id {
         if preexisting_id != *workspace_id {
             // Should also delete fields in other tables with cascading updates
-            connection.execute(
+            connection.prepare(
                 "DELETE FROM workspaces WHERE workspace_id = ?",
-                [preexisting_id.0],
-            )?;
+            )?
+            .bind(preexisting_id.0)?
+            .exec()?;
         }
     }
 
-    connection.execute(
-        "DELETE FROM worktree_roots WHERE workspace_id = ?",
-        [workspace_id.0],
-    )?;
+    connection
+        .prepare("DELETE FROM worktree_roots WHERE workspace_id = ?")?
+        .bind(workspace_id.0)?
+        .exec()?;
 
     for root in worktree_roots {
         let path = root.as_ref().as_os_str().as_bytes();
         // If you need to debug this, here's the string parsing:
         // let path = root.as_ref().to_string_lossy().to_string();
 
-        connection.execute(
-            "INSERT INTO worktree_roots(workspace_id, worktree_root) VALUES (?, ?)",
-            params![workspace_id.0, path],
-        )?;
+        connection.prepare("INSERT INTO worktree_roots(workspace_id, worktree_root) VALUES (?, ?)")?
+            .bind((workspace_id.0, path))?
+            .exec()?;
     }
 
-    connection.execute(
-        "UPDATE workspaces SET last_opened_timestamp = ? WHERE workspace_id = ?",
-        params![current_millis()?, workspace_id.0],
-    )?;
+    connection.prepare("UPDATE workspaces SET last_opened_timestamp = CURRENT_TIMESTAMP WHERE workspace_id = ?")?
+        .bind(workspace_id.0)?
+        .exec()?;
 
     Ok(())
 }
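Note the timestamp change here: the row is now stamped with SQLite's own `CURRENT_TIMESTAMP` instead of a caller-supplied value, which is why the `current_millis()` helper (and its 584-million-year caveat) could be deleted above. `CURRENT_TIMESTAMP` yields UTC `YYYY-MM-DD HH:MM:SS` text, which still orders correctly under `ORDER BY last_opened_timestamp DESC`. As a standalone, hypothetical helper:

```rust
use sqlez::connection::Connection;

// Let SQLite assign the wall-clock time; no system-clock plumbing in Rust.
fn touch_workspace(connection: &Connection, workspace_id: i64) -> anyhow::Result<()> {
    connection
        .prepare("UPDATE workspaces SET last_opened_timestamp = CURRENT_TIMESTAMP WHERE workspace_id = ?")?
        .bind(workspace_id)?
        .exec()
}
```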
@@ -300,13 +241,6 @@ fn get_workspace_id<P>(worktree_roots: &[P], connection: &Connection) -> Result<Option<WorkspaceId>>
 where
     P: AsRef<Path> + Debug,
 {
-    // fn logic<P>(
-    //     worktree_roots: &[P],
-    //     connection: &Connection,
-    // ) -> Result<Option<WorkspaceId>, anyhow::Error>
-    // where
-    //     P: AsRef<Path> + Debug,
-    // {
     // Short circuit if we can
     if worktree_roots.len() == 0 {
         return Ok(None);
     }
@@ -324,6 +258,7 @@ where
         }
     }
     array_binding_stmt.push(')');
+
     // Any workspace can have multiple independent paths, and these paths
     // can overlap in the database. Take this test data for example:
     //
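The surviving body of `get_workspace_id` builds `array_binding_stmt`, one `?` per worktree root plus a trailing `?` for the count, because SQLite has no native array parameters. A sketch of that construction (the exact separator handling in the elided original may differ):

```rust
// Build "(?,?,?)" for n = 3; bound afterwards via bind_value at 1-based indices.
fn array_binding_stmt(n: usize) -> String {
    let mut stmt = String::from("(");
    for i in 0..n {
        if i > 0 {
            stmt.push(',');
        }
        stmt.push('?');
    }
    stmt.push(')');
    stmt
}

#[test]
fn three_roots() {
    assert_eq!(array_binding_stmt(3), "(?,?,?)");
}
```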
@@ -393,43 +328,19 @@ where
     // caching it.
     let mut stmt = connection.prepare(&query)?;
     // Make sure we bound the parameters correctly
-    debug_assert!(worktree_roots.len() + 1 == stmt.parameter_count());
+    debug_assert!(worktree_roots.len() as i32 + 1 == stmt.parameter_count());
 
     for i in 0..worktree_roots.len() {
         let path = &worktree_roots[i].as_ref().as_os_str().as_bytes();
         // If you need to debug this, here's the string parsing:
         // let path = &worktree_roots[i].as_ref().to_string_lossy().to_string()
-        stmt.raw_bind_parameter(i + 1, path)?
+        stmt.bind_value(*path, i as i32 + 1);
     }
     // No -1, because SQLite is 1 based
-    stmt.raw_bind_parameter(worktree_roots.len() + 1, worktree_roots.len())?;
-
-    let mut rows = stmt.raw_query();
-    let row = rows.next();
-    let result = if let Ok(Some(row)) = row {
-        Ok(Some(WorkspaceId(row.get(0)?)))
-    } else {
-        Ok(None)
-    };
+    stmt.bind_value(worktree_roots.len(), worktree_roots.len() as i32 + 1)?;
 
-    // Ensure that this query only returns one row. The PRIMARY KEY constraint should catch this case
-    // but this is here to catch if someone refactors that constraint out.
-    debug_assert!(matches!(rows.next(), Ok(None)));
-
-    result
-    // }
-
-    // match logic(worktree_roots, connection) {
-    //     Ok(result) => result,
-    //     Err(err) => {
-    //         log::error!(
-    //             "Failed to get the workspace ID for paths {:?}, err: {}",
-    //             worktree_roots,
-    //             err
-    //         );
-    //         None
-    //     }
-    // }
+    stmt.maybe_row()
+        .map(|row| row.map(|id| WorkspaceId(id)))
 }
 
 #[cfg(test)]
diff --git a/crates/sqlez/src/connection.rs b/crates/sqlez/src/connection.rs
index 81bb9dfe78b27f5745b4e5b528c9910b1a027c22..be529784951a10ddd5a19d1b19b04774a9c3bfb2 100644
--- a/crates/sqlez/src/connection.rs
+++ b/crates/sqlez/src/connection.rs
@@ -53,6 +53,15 @@ impl Connection {
         self.persistent
     }
 
+    pub(crate) fn last_insert_id(&self) -> i64 {
+        unsafe { sqlite3_last_insert_rowid(self.sqlite3) }
+    }
+
+    pub fn insert(&self, query: impl AsRef<str>) -> Result<i64> {
+        self.exec(query)?;
+        Ok(self.last_insert_id())
+    }
+
     pub fn exec(&self, query: impl AsRef<str>) -> Result<()> {
         unsafe {
             sqlite3_exec(
@@ -140,9 +149,9 @@ mod test {
         connection
             .prepare("INSERT INTO text (text) VALUES (?);")
             .unwrap()
-            .bound(text)
+            .bind(text)
             .unwrap()
-            .run()
+            .exec()
             .unwrap();
 
         assert_eq!(
@@ -176,8 +185,8 @@ mod test {
             .prepare("INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)")
             .unwrap();
 
-        insert.bound(tuple1.clone()).unwrap().run().unwrap();
-        insert.bound(tuple2.clone()).unwrap().run().unwrap();
+        insert.bind(tuple1.clone()).unwrap().exec().unwrap();
+        insert.bind(tuple2.clone()).unwrap().exec().unwrap();
 
         assert_eq!(
             connection
@@ -203,7 +212,7 @@ mod test {
             .prepare("INSERT INTO blobs (data) VALUES (?);")
             .unwrap();
         write.bind_blob(1, blob).unwrap();
-        write.run().unwrap();
+        write.exec().unwrap();
 
         // Backup connection1 to connection2
         let connection2 = Connection::open_memory("backup_works_other");
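`Connection::insert` above is the connection-level twin of `Statement::insert`: run the SQL, then read back `sqlite3_last_insert_rowid`. A small usage sketch:

```rust
use sqlez::connection::Connection;

fn insert_returns_rowid() -> anyhow::Result<()> {
    let connection = Connection::open_memory("insert_example");
    connection.exec("CREATE TABLE items(item_id INTEGER PRIMARY KEY, kind TEXT)")?;

    // Each call reports the PRIMARY KEY of the row it just created.
    let first = connection.insert("INSERT INTO items(kind) VALUES ('Terminal')")?;
    let second = connection.insert("INSERT INTO items(kind) VALUES ('Editor')")?;
    assert!(second > first);
    Ok(())
}
```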
diff --git a/crates/sqlez/src/migrations.rs b/crates/sqlez/src/migrations.rs
index 4721b353c68e715a77f75676681566d43e2b8c8a..3c0771c0feb7a1f6df931f41f93618656c19b181 100644
--- a/crates/sqlez/src/migrations.rs
+++ b/crates/sqlez/src/migrations.rs
@@ -22,6 +22,7 @@ const MIGRATIONS_MIGRATION: Migration = Migration::new(
     "}],
 );
 
+#[derive(Debug)]
 pub struct Migration {
     domain: &'static str,
     migrations: &'static [&'static str],
@@ -46,7 +47,7 @@ impl Migration {
                 WHERE domain = ?
                 ORDER BY step
             "})?
-            .bound(self.domain)?
+            .bind(self.domain)?
             .rows::<(String, usize, String)>()?;
 
         let mut store_completed_migration = connection
@@ -71,8 +72,8 @@ impl Migration {
             connection.exec(migration)?;
 
             store_completed_migration
-                .bound((self.domain, index, *migration))?
-                .run()?;
+                .bind((self.domain, index, *migration))?
+                .exec()?;
         }
 
         Ok(())
@@ -162,9 +163,9 @@ mod test {
                 .unwrap();
 
             store_completed_migration
-                .bound((domain, i, i.to_string()))
+                .bind((domain, i, i.to_string()))
                 .unwrap()
-                .run()
+                .exec()
                 .unwrap();
         }
     }
diff --git a/crates/sqlez/src/savepoint.rs b/crates/sqlez/src/savepoint.rs
index 749c0dc9487641c125d880d32817f0a0612636b9..50f28c73901d2382f1ef677425af1e835ea9678b 100644
--- a/crates/sqlez/src/savepoint.rs
+++ b/crates/sqlez/src/savepoint.rs
@@ -3,10 +3,36 @@ use anyhow::Result;
 use crate::connection::Connection;
 
 impl Connection {
+    // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
+    // returns Err(_), the savepoint will be rolled back. Otherwise, the save
+    // point is released.
+    pub fn with_savepoint<R, F>(&mut self, name: impl AsRef<str>, f: F) -> Result<R>
+    where
+        F: FnOnce(&mut Connection) -> Result<R>,
+    {
+        let name = name.as_ref().to_owned();
+        self.exec(format!("SAVEPOINT {}", &name))?;
+        let result = f(self);
+        match result {
+            Ok(_) => {
+                self.exec(format!("RELEASE {}", name))?;
+            }
+            Err(_) => {
+                self.exec(format!("ROLLBACK TO {}", name))?;
+                self.exec(format!("RELEASE {}", name))?;
+            }
+        }
+        result
+    }
+
     // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
     // returns Ok(None) or Err(_), the savepoint will be rolled back. Otherwise, the save
     // point is released.
-    pub fn with_savepoint<R, F>(&mut self, name: impl AsRef<str>, f: F) -> Result<Option<R>>
+    pub fn with_savepoint_rollback<R, F>(
+        &mut self,
+        name: impl AsRef<str>,
+        f: F,
+    ) -> Result<Option<R>>
     where
         F: FnOnce(&mut Connection) -> Result<Option<R>>,
     {
@@ -50,15 +76,15 @@ mod tests {
         connection.with_savepoint("first", |save1| {
             save1
                 .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
-                .bound((save1_text, 1))?
-                .run()?;
+                .bind((save1_text, 1))?
+                .exec()?;
 
             assert!(save1
                 .with_savepoint("second", |save2| -> Result<Option<()>, anyhow::Error> {
                     save2
                         .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
-                        .bound((save2_text, 2))?
-                        .run()?;
+                        .bind((save2_text, 2))?
+                        .exec()?;
 
                     assert_eq!(
                         save2
@@ -79,11 +105,34 @@ mod tests {
                 vec![save1_text],
             );
 
-            save1.with_savepoint("second", |save2| {
+            save1.with_savepoint_rollback::<(), _>("second", |save2| {
                 save2
                     .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
-                    .bound((save2_text, 2))?
-                    .run()?;
+                    .bind((save2_text, 2))?
+                    .exec()?;
+
+                assert_eq!(
+                    save2
+                        .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
+                        .rows::<String>()?,
+                    vec![save1_text, save2_text],
+                );
+
+                Ok(None)
+            })?;
+
+            assert_eq!(
+                save1
+                    .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
+                    .rows::<String>()?,
+                vec![save1_text],
+            );
+
+            save1.with_savepoint_rollback("second", |save2| {
+                save2
+                    .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
+                    .bind((save2_text, 2))?
+                    .exec()?;
 
                 assert_eq!(
                     save2
@@ -102,9 +151,16 @@ mod tests {
                 vec![save1_text, save2_text],
             );
 
-            Ok(Some(()))
+            Ok(())
         })?;
 
+        assert_eq!(
+            connection
+                .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
+                .rows::<String>()?,
+            vec![save1_text, save2_text],
+        );
+
         Ok(())
     }
 }
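The savepoint split gives the two behaviors distinct names: `with_savepoint` rolls back only on `Err(_)`, while `with_savepoint_rollback` keeps the old semantics and also rolls back on `Ok(None)`. A condensed version of what the test above exercises:

```rust
use sqlez::connection::Connection;

fn savepoint_flavors() -> anyhow::Result<()> {
    let mut connection = Connection::open_memory("savepoint_flavors");
    connection.exec("CREATE TABLE text(text TEXT, idx INTEGER)")?;

    // Released: the closure returned Ok, so the write sticks.
    connection.with_savepoint("keep", |conn| {
        conn.prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
            .bind(("kept", 1))?
            .exec()
    })?;

    // Rolled back: Ok(None) means "succeed, but undo my writes".
    connection.with_savepoint_rollback::<(), _>("discard", |conn| {
        conn.prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
            .bind(("discarded", 2))?
            .exec()?;
        Ok(None)
    })?;

    let rows = connection
        .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
        .rows::<String>()?;
    assert_eq!(rows, vec!["kept".to_string()]);
    Ok(())
}
```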
diff --git a/crates/sqlez/src/statement.rs b/crates/sqlez/src/statement.rs
index 774cda0e344c4b85bf2f258937067361a6ff3aa2..ac57847774b1ad37e1f2c5a7d47f653e5d9a363e 100644
--- a/crates/sqlez/src/statement.rs
+++ b/crates/sqlez/src/statement.rs
@@ -60,6 +60,10 @@ impl<'a> Statement<'a> {
         }
     }
 
+    pub fn parameter_count(&self) -> i32 {
+        unsafe { sqlite3_bind_parameter_count(self.raw_statement) }
+    }
+
     pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
         let index = index as c_int;
         let blob_pointer = blob.as_ptr() as *const _;
@@ -175,8 +179,9 @@ impl<'a> Statement<'a> {
         Ok(str::from_utf8(slice)?)
     }
 
-    pub fn bind<T: Bind>(&self, value: T) -> Result<()> {
-        value.bind(self, 1)?;
+    pub fn bind_value<T: Bind>(&self, value: T, idx: i32) -> Result<()> {
+        debug_assert!(idx > 0);
+        value.bind(self, idx)?;
         Ok(())
     }
 
@@ -198,8 +203,8 @@ impl<'a> Statement<'a> {
         }
     }
 
-    pub fn bound(&mut self, bindings: impl Bind) -> Result<&mut Self> {
-        self.bind(bindings)?;
+    pub fn bind(&mut self, bindings: impl Bind) -> Result<&mut Self> {
+        self.bind_value(bindings, 1)?;
         Ok(self)
     }
 
@@ -217,7 +222,12 @@ impl<'a> Statement<'a> {
         }
     }
 
-    pub fn run(&mut self) -> Result<()> {
+    pub fn insert(&mut self) -> Result<i64> {
+        self.exec()?;
+        Ok(self.connection.last_insert_id())
+    }
+
+    pub fn exec(&mut self) -> Result<()> {
         fn logic(this: &mut Statement) -> Result<()> {
             while this.step()? == StepResult::Row {}
             Ok(())
diff --git a/crates/sqlez/src/thread_safe_connection.rs b/crates/sqlez/src/thread_safe_connection.rs
index 8885edc2c0a52f1d6514be0d1c9fc8483966c410..53d49464bed97fff60d9f9aed17882161f5f5465 100644
--- a/crates/sqlez/src/thread_safe_connection.rs
+++ b/crates/sqlez/src/thread_safe_connection.rs
@@ -3,12 +3,13 @@ use std::{ops::Deref, sync::Arc};
 use connection::Connection;
 use thread_local::ThreadLocal;
 
-use crate::connection;
+use crate::{connection, migrations::Migration};
 
 pub struct ThreadSafeConnection {
     uri: Arc<str>,
     persistent: bool,
     initialize_query: Option<&'static str>,
+    migrations: Option<&'static [Migration]>,
     connection: Arc<ThreadLocal<Connection>>,
 }
 
@@ -18,6 +19,7 @@ impl ThreadSafeConnection {
             uri: Arc::from(uri),
             persistent,
             initialize_query: None,
+            migrations: None,
             connection: Default::default(),
         }
     }
@@ -29,6 +31,11 @@ impl ThreadSafeConnection {
         self
     }
 
+    pub fn with_migrations(mut self, migrations: &'static [Migration]) -> Self {
+        self.migrations = Some(migrations);
+        self
+    }
+
     /// Opens a new db connection with the initialized file path. This is internal and only
     /// called from the deref function.
     /// If opening fails, the connection falls back to a shared memory connection
@@ -49,6 +56,7 @@ impl Clone for ThreadSafeConnection {
             uri: self.uri.clone(),
             persistent: self.persistent,
             initialize_query: self.initialize_query.clone(),
+            migrations: self.migrations.clone(),
             connection: self.connection.clone(),
         }
     }
@@ -72,6 +80,14 @@ impl Deref for ThreadSafeConnection {
                 ));
             }
 
+            if let Some(migrations) = self.migrations {
+                for migration in migrations {
+                    migration
+                        .run(&connection)
+                        .expect(&format!("Migrations failed to execute: {:?}", migration));
+                }
+            }
+
             connection
         })
    }
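End to end, the pieces compose at the `ThreadSafeConnection` builder, and the `Deref` hook above means each thread's lazily opened connection runs the migrations before first use, mirroring `Db::open` in `crates/db/src/db.rs` at the top of the patch. In the sketch below, `with_initialize_query` is an assumed builder name (only the `initialize_query` field appears in this diff), and the migration is the hypothetical `example` domain from the earlier sketch:

```rust
use sqlez::{migrations::Migration, thread_safe_connection::ThreadSafeConnection};

// Hypothetical domain; see the earlier migration sketch.
const EXAMPLE_MIGRATION: Migration =
    Migration::new("example", &["CREATE TABLE example(id INTEGER PRIMARY KEY) STRICT;"]);

fn open_connection() -> ThreadSafeConnection {
    // persistent = true: back the connection with a file rather than memory.
    ThreadSafeConnection::new("example.db", true)
        .with_initialize_query("PRAGMA foreign_keys=TRUE;") // assumed builder method
        .with_migrations(&[EXAMPLE_MIGRATION])
}
```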