Finished implementing the workspace stuff

Mikayla Maki created

Change summary

crates/db/src/db.rs                        |   7 
crates/db/src/kvp.rs                       |  67 ++---
crates/db/src/migrations.rs                |  28 +-
crates/db/src/pane.rs                      |  18 
crates/db/src/workspace.rs                 | 281 ++++++++---------------
crates/sqlez/src/connection.rs             |  19 +
crates/sqlez/src/migrations.rs             |  11 
crates/sqlez/src/savepoint.rs              |  74 +++++
crates/sqlez/src/statement.rs              |  20 +
crates/sqlez/src/thread_safe_connection.rs |  18 +
10 files changed, 264 insertions(+), 279 deletions(-)

Detailed changes

crates/db/src/db.rs 🔗

@@ -10,6 +10,8 @@ use std::path::Path;
 
 use anyhow::Result;
 use indoc::indoc;
+use kvp::KVP_MIGRATION;
+use pane::PANE_MIGRATIONS;
 use sqlez::connection::Connection;
 use sqlez::thread_safe_connection::ThreadSafeConnection;
 
@@ -42,7 +44,8 @@ impl Db {
                     PRAGMA synchronous=NORMAL;
                     PRAGMA foreign_keys=TRUE;
                     PRAGMA case_sensitive_like=TRUE;
-                "}),
+                "})
+                .with_migrations(&[KVP_MIGRATION, WORKSPACES_MIGRATION, PANE_MIGRATIONS]),
         )
     }
 
@@ -64,7 +67,7 @@ impl Db {
 
     pub fn write_to<P: AsRef<Path>>(&self, dest: P) -> Result<()> {
         let destination = Connection::open_file(dest.as_ref().to_string_lossy().as_ref());
-        self.backup(&destination)
+        self.backup_main(&destination)
     }
 }
 

crates/db/src/kvp.rs 🔗

@@ -1,55 +1,38 @@
-use anyhow::Result;
-use rusqlite::OptionalExtension;
-
 use super::Db;
-
-pub(crate) const KVP_M_1: &str = "
-CREATE TABLE kv_store(
-    key TEXT PRIMARY KEY,
-    value TEXT NOT NULL
-) STRICT;
-";
+use anyhow::Result;
+use indoc::indoc;
+use sqlez::migrations::Migration;
+
+pub(crate) const KVP_MIGRATION: Migration = Migration::new(
+    "kvp",
+    &[indoc! {"
+    CREATE TABLE kv_store(
+        key TEXT PRIMARY KEY,
+        value TEXT NOT NULL
+    ) STRICT;
+    "}],
+);
 
 impl Db {
     pub fn read_kvp(&self, key: &str) -> Result<Option<String>> {
-        self.real()
-            .map(|db| {
-                let lock = db.connection.lock();
-                let mut stmt = lock.prepare_cached("SELECT value FROM kv_store WHERE key = (?)")?;
-
-                Ok(stmt.query_row([key], |row| row.get(0)).optional()?)
-            })
-            .unwrap_or(Ok(None))
+        self.0
+            .prepare("SELECT value FROM kv_store WHERE key = (?)")?
+            .bind(key)?
+            .maybe_row()
     }
 
     pub fn write_kvp(&self, key: &str, value: &str) -> Result<()> {
-        self.real()
-            .map(|db| {
-                let lock = db.connection.lock();
-
-                let mut stmt = lock.prepare_cached(
-                    "INSERT OR REPLACE INTO kv_store(key, value) VALUES ((?), (?))",
-                )?;
-
-                stmt.execute([key, value])?;
-
-                Ok(())
-            })
-            .unwrap_or(Ok(()))
+        self.0
+            .prepare("INSERT OR REPLACE INTO kv_store(key, value) VALUES (?, ?)")?
+            .bind((key, value))?
+            .exec()
     }
 
     pub fn delete_kvp(&self, key: &str) -> Result<()> {
-        self.real()
-            .map(|db| {
-                let lock = db.connection.lock();
-
-                let mut stmt = lock.prepare_cached("DELETE FROM kv_store WHERE key = (?)")?;
-
-                stmt.execute([key])?;
-
-                Ok(())
-            })
-            .unwrap_or(Ok(()))
+        self.0
+            .prepare("DELETE FROM kv_store WHERE key = (?)")?
+            .bind(key)?
+            .exec()
     }
 }
 

crates/db/src/migrations.rs 🔗

@@ -1,16 +1,14 @@
-use rusqlite_migration::{Migrations, M};
+// // use crate::items::ITEMS_M_1;
+// use crate::{kvp::KVP_M_1, pane::PANE_M_1, WORKSPACES_MIGRATION};
 
-// use crate::items::ITEMS_M_1;
-use crate::{kvp::KVP_M_1, pane::PANE_M_1, WORKSPACE_M_1};
-
-// This must be ordered by development time! Only ever add new migrations to the end!!
-// Bad things will probably happen if you don't monotonically edit this vec!!!!
-// And no re-ordering ever!!!!!!!!!! The results of these migrations are on the user's
-// file system and so everything we do here is locked in _f_o_r_e_v_e_r_.
-lazy_static::lazy_static! {
-    pub static ref MIGRATIONS: Migrations<'static> = Migrations::new(vec![
-        M::up(KVP_M_1),
-        M::up(WORKSPACE_M_1),
-        M::up(PANE_M_1)
-    ]);
-}
+// // This must be ordered by development time! Only ever add new migrations to the end!!
+// // Bad things will probably happen if you don't monotonically edit this vec!!!!
+// // And no re-ordering ever!!!!!!!!!! The results of these migrations are on the user's
+// // file system and so everything we do here is locked in _f_o_r_e_v_e_r_.
+// lazy_static::lazy_static! {
+//     pub static ref MIGRATIONS: Migrations<'static> = Migrations::new(vec![
+//         M::up(KVP_M_1),
+//         M::up(WORKSPACE_M_1),
+//         M::up(PANE_M_1)
+//     ]);
+// }

crates/db/src/pane.rs 🔗

@@ -1,15 +1,14 @@
 
 use gpui::Axis;
+use indoc::indoc;
+use sqlez::migrations::Migration;
 
-use rusqlite::{OptionalExtension, Connection};
-use serde::{Deserialize, Serialize};
-use serde_rusqlite::{from_row, to_params_named};
 
 use crate::{items::ItemId, workspace::WorkspaceId};
 
 use super::Db;
 
-pub(crate) const PANE_M_1: &str = "
+pub(crate) const PANE_MIGRATIONS: Migration = Migration::new("pane", &[indoc! {"
 CREATE TABLE dock_panes(
     dock_pane_id INTEGER PRIMARY KEY,
     workspace_id INTEGER NOT NULL,
@@ -64,7 +63,7 @@ CREATE TABLE dock_items(
     FOREIGN KEY(dock_pane_id) REFERENCES dock_panes(dock_pane_id) ON DELETE CASCADE,
     FOREIGN KEY(item_id) REFERENCES items(item_id)ON DELETE CASCADE
 ) STRICT;
-";
+"}]);
 
 // We have an many-branched, unbalanced tree with three types:
 // Pane Groups
@@ -140,7 +139,7 @@ pub struct SerializedPane {
 //********* CURRENTLY IN USE TYPES: *********
 
 
-#[derive(Default, Debug, PartialEq, Eq, Deserialize, Serialize)]
+#[derive(Default, Debug, PartialEq, Eq)]
 pub enum DockAnchor {
     #[default]
     Bottom,
@@ -148,7 +147,7 @@ pub enum DockAnchor {
     Expanded,
 }
 
-#[derive(Default, Debug, PartialEq, Eq, Deserialize, Serialize)]
+#[derive(Default, Debug, PartialEq, Eq)]
 pub struct SerializedDockPane {
     pub anchor_position: DockAnchor,
     pub visible: bool,
@@ -160,7 +159,7 @@ impl SerializedDockPane {
     }
 }
 
-#[derive(Default, Debug, PartialEq, Eq, Deserialize, Serialize)]
+#[derive(Default, Debug, PartialEq, Eq)]
 pub(crate) struct DockRow {
     workspace_id: WorkspaceId,
     anchor_position: DockAnchor,
@@ -298,12 +297,11 @@ mod tests {
         let workspace = db.workspace_for_roots(&["/tmp"]);
 
         let dock_pane = SerializedDockPane {
-            workspace_id: workspace.workspace_id,
             anchor_position: DockAnchor::Expanded,
             visible: true,
         };
 
-        db.save_dock_pane(&dock_pane);
+        db.save_dock_pane(workspace.workspace_id, dock_pane);
 
         let new_workspace = db.workspace_for_roots(&["/tmp"]);
 

crates/db/src/workspace.rs 🔗

@@ -6,12 +6,12 @@ use std::{
     os::unix::prelude::OsStrExt,
     path::{Path, PathBuf},
     sync::Arc,
-    time::{SystemTime, UNIX_EPOCH},
 };
 
-use anyhow::Result;
 use indoc::indoc;
-use sqlez::{connection::Connection, migrations::Migration};
+use sqlez::{
+    connection::Connection, migrations::Migration,
+};
 
 use crate::pane::SerializedDockPane;
 
@@ -20,8 +20,8 @@ use super::Db;
 // If you need to debug the worktree root code, change 'BLOB' here to 'TEXT' for easier debugging
 // you might want to update some of the parsing code as well, I've left the variations in but commented
 // out. This will panic if run on an existing db that has already been migrated
-const WORKSPACES_MIGRATION: Migration = Migration::new(
-    "migrations",
+pub(crate) const WORKSPACES_MIGRATION: Migration = Migration::new(
+    "workspace",
     &[indoc! {"
             CREATE TABLE workspaces(
                 workspace_id INTEGER PRIMARY KEY,
@@ -53,8 +53,8 @@ pub struct SerializedWorkspace {
 }
 
 impl Db {
-    /// Finds or creates a workspace id for the given set of worktree roots. If the passed worktree roots is empty, return the
-    /// the last workspace id
+    /// Finds or creates a workspace id for the given set of worktree roots. If the passed worktree roots is empty,
+    /// returns the last workspace which was updated
     pub fn workspace_for_roots<P>(&self, worktree_roots: &[P]) -> SerializedWorkspace
     where
         P: AsRef<Path> + Debug,
@@ -80,23 +80,21 @@ impl Db {
     where
         P: AsRef<Path> + Debug,
     {
-        let result = (|| {
-            let tx = self.transaction()?;
-            tx.execute("INSERT INTO workspaces(last_opened_timestamp) VALUES" (?), [current_millis()?])?;
-
-            let id = WorkspaceId(tx.last_insert_rowid());
+        let res = self.with_savepoint("make_new_workspace", |conn| {
+            let workspace_id = WorkspaceId(
+                conn.prepare("INSERT INTO workspaces DEFAULT VALUES")?
+                    .insert()?,
+            );
 
-            update_worktree_roots(&tx, &id, worktree_roots)?;
-
-            tx.commit()?;
+            update_worktree_roots(conn, &workspace_id, worktree_roots)?;
 
             Ok(SerializedWorkspace {
-                workspace_id: id,
+                workspace_id,
                 dock_pane: None,
             })
-        })();
+        });
 
-        match result {
+        match res {
             Ok(serialized_workspace) => serialized_workspace,
             Err(err) => {
                 log::error!("Failed to insert new workspace into DB: {}", err);
@@ -109,19 +107,13 @@ impl Db {
     where
         P: AsRef<Path> + Debug,
     {
-        self.real()
-            .map(|db| {
-                let lock = db.connection.lock();
-
-                match get_workspace_id(worktree_roots, &lock) {
-                    Ok(workspace_id) => workspace_id,
-                    Err(err) => {
-                        log::error!("Failed to get workspace_id: {}", err);
-                        None
-                    }
-                }
-            })
-            .unwrap_or(None)
+        match get_workspace_id(worktree_roots, &self) {
+            Ok(workspace_id) => workspace_id,
+            Err(err) => {
+                log::error!("Failed to get workspace_id: {}", err);
+                None
+            }
+        }
     }
 
     // fn get_workspace_row(&self, workspace_id: WorkspaceId) -> WorkspaceRow {
@@ -135,123 +127,73 @@ impl Db {
     where
         P: AsRef<Path> + Debug,
     {
-        fn logic<P>(
-            connection: &mut Connection,
-            workspace_id: &WorkspaceId,
-            worktree_roots: &[P],
-        ) -> Result<()>
-        where
-            P: AsRef<Path> + Debug,
-        {
-            let tx = connection.transaction()?;
-            update_worktree_roots(&tx, workspace_id, worktree_roots)?;
-            tx.commit()?;
-            Ok(())
+        match self.with_savepoint("update_worktrees", |conn| {
+            update_worktree_roots(conn, workspace_id, worktree_roots)
+        }) {
+            Ok(_) => {}
+            Err(err) => log::error!(
+                "Failed to update workspace {:?} with roots {:?}, error: {}",
+                workspace_id,
+                worktree_roots,
+                err
+            ),
         }
-
-        self.real().map(|db| {
-            let mut lock = db.connection.lock();
-
-            match logic(&mut lock, workspace_id, worktree_roots) {
-                Ok(_) => {}
-                Err(err) => {
-                    log::error!(
-                        "Failed to update the worktree roots for {:?}, roots: {:?}, error: {}",
-                        workspace_id,
-                        worktree_roots,
-                        err
-                    );
-                }
-            }
-        });
     }
 
     fn last_workspace_id(&self) -> Option<WorkspaceId> {
-        fn logic(connection: &mut Connection) -> Result<Option<WorkspaceId>> {
-            let mut stmt = connection.prepare(
+        let res = self
+            .prepare(
                 "SELECT workspace_id FROM workspaces ORDER BY last_opened_timestamp DESC LIMIT 1",
-            )?;
+            )
+            .and_then(|stmt| stmt.maybe_row())
+            .map(|row| row.map(|id| WorkspaceId(id)));
 
-            Ok(stmt
-                .query_row([], |row| Ok(WorkspaceId(row.get(0)?)))
-                .optional()?)
+        match res {
+            Ok(result) => result,
+            Err(err) => {
+                log::error!("Failed to get last workspace id, err: {}", err);
+                return None;
+            }
         }
-
-        self.real()
-            .map(|db| {
-                let mut lock = db.connection.lock();
-
-                match logic(&mut lock) {
-                    Ok(result) => result,
-                    Err(err) => {
-                        log::error!("Failed to get last workspace id, err: {}", err);
-                        None
-                    }
-                }
-            })
-            .unwrap_or(None)
     }
 
     /// Returns the previous workspace ids sorted by last modified along with their opened worktree roots
     pub fn recent_workspaces(&self, limit: usize) -> Vec<(WorkspaceId, Vec<Arc<Path>>)> {
-        fn logic(
-            connection: &mut Connection,
-            limit: usize,
-        ) -> Result<Vec<(WorkspaceId, Vec<Arc<Path>>)>, anyhow::Error> {
-            let tx = connection.transaction()?;
-            let result = {
-                let mut stmt = tx.prepare(
-                    "SELECT workspace_id FROM workspaces ORDER BY last_opened_timestamp DESC LIMIT ?",
-                )?;
-
-                let workspace_ids = stmt
-                    .query_map([limit], |row| Ok(WorkspaceId(row.get(0)?)))?
-                    .collect::<Result<Vec<_>, rusqlite::Error>>()?;
-
-                let mut result = Vec::new();
-                let mut stmt =
-                    tx.prepare("SELECT worktree_root FROM worktree_roots WHERE workspace_id = ?")?;
-                for workspace_id in workspace_ids {
-                    let roots = stmt
-                        .query_map([workspace_id.0], |row| {
-                            let row = row.get::<_, Vec<u8>>(0)?;
-                            Ok(PathBuf::from(OsStr::from_bytes(&row)).into())
-                            // If you need to debug this, here's the string parsing:
-                            // let row = row.get::<_, String>(0)?;
-                            // Ok(PathBuf::from(row).into())
-                        })?
-                        .collect::<Result<Vec<_>, rusqlite::Error>>()?;
-                    result.push((workspace_id, roots))
-                }
-
-                result
-            };
-            tx.commit()?;
-            return Ok(result);
-        }
-
-        self.real()
-            .map(|db| {
-                let mut lock = db.connection.lock();
-
-                match logic(&mut lock, limit) {
-                    Ok(result) => result,
-                    Err(err) => {
-                        log::error!("Failed to get recent workspaces, err: {}", err);
-                        Vec::new()
-                    }
-                }
-            })
-            .unwrap_or_else(|| Vec::new())
+        let res = self.with_savepoint("recent_workspaces", |conn| {
+            let ids = conn.prepare("SELECT workspace_id FROM workspaces ORDER BY last_opened_timestamp DESC LIMIT ?")?
+                .bind(limit)?
+                .rows::<i64>()?
+                .iter()
+                .map(|row| WorkspaceId(*row));
+            
+            let result = Vec::new();
+            
+            let stmt = conn.prepare("SELECT worktree_root FROM worktree_roots WHERE workspace_id = ?")?;
+            for workspace_id in ids {
+                let roots = stmt.bind(workspace_id.0)?
+                    .rows::<Vec<u8>>()?
+                    .iter()
+                    .map(|row| {
+                        PathBuf::from(OsStr::from_bytes(&row)).into()
+                    })
+                    .collect();
+                result.push((workspace_id, roots))
+            }
+            
+            
+            Ok(result)
+        });
+        
+        match res {
+            Ok(result) => result,
+            Err(err) => {
+                log::error!("Failed to get recent workspaces, err: {}", err);
+                Vec::new()
+            }
+        }        
     }
 }
 
-fn current_millis() -> Result<u64, anyhow::Error> {
-    // SQLite only supports u64 integers, which means this code will trigger
-    // undefined behavior in 584 million years. It's probably fine.
-    Ok(SystemTime::now().duration_since(UNIX_EPOCH)?.as_millis() as u64)
-}
-
 fn update_worktree_roots<P>(
     connection: &Connection,
     workspace_id: &WorkspaceId,
@@ -265,33 +207,32 @@ where
     if let Some(preexisting_id) = preexisting_id {
         if preexisting_id != *workspace_id {
             // Should also delete fields in other tables with cascading updates
-            connection.execute(
+            connection.prepare(
                 "DELETE FROM workspaces WHERE workspace_id = ?",
-                [preexisting_id.0],
-            )?;
+            )?
+            .bind(preexisting_id.0)?
+            .exec()?;
         }
     }
 
-    connection.execute(
-        "DELETE FROM worktree_roots WHERE workspace_id = ?",
-        [workspace_id.0],
-    )?;
+    connection
+        .prepare("DELETE FROM worktree_roots WHERE workspace_id = ?")?
+        .bind(workspace_id.0)?
+        .exec()?;
 
     for root in worktree_roots {
         let path = root.as_ref().as_os_str().as_bytes();
         // If you need to debug this, here's the string parsing:
         // let path = root.as_ref().to_string_lossy().to_string();
 
-        connection.execute(
-            "INSERT INTO worktree_roots(workspace_id, worktree_root) VALUES (?, ?)",
-            params![workspace_id.0, path],
-        )?;
+        connection.prepare("INSERT INTO worktree_roots(workspace_id, worktree_root) VALUES (?, ?)")?
+            .bind((workspace_id.0, path))?
+            .exec()?;
     }
 
-    connection.execute(
-        "UPDATE workspaces SET last_opened_timestamp = ? WHERE workspace_id = ?",
-        params![current_millis()?, workspace_id.0],
-    )?;
+    connection.prepare("UPDATE workspaces SET last_opened_timestamp = CURRENT_TIMESTAMP WHERE workspace_id = ?")?
+        .bind(workspace_id.0)?
+        .exec()?;
 
     Ok(())
 }
@@ -300,13 +241,6 @@ fn get_workspace_id<P>(worktree_roots: &[P], connection: &Connection) -> Result<
 where
     P: AsRef<Path> + Debug,
 {
-    // fn logic<P>(
-    //     worktree_roots: &[P],
-    //     connection: &Connection,
-    // ) -> Result<Option<WorkspaceId>, anyhow::Error>
-    // where
-    //     P: AsRef<Path> + Debug,
-    // {
     // Short circuit if we can
     if worktree_roots.len() == 0 {
         return Ok(None);
@@ -324,6 +258,7 @@ where
         }
     }
     array_binding_stmt.push(')');
+    
     // Any workspace can have multiple independent paths, and these paths
     // can overlap in the database. Take this test data for example:
     //
@@ -393,43 +328,19 @@ where
     // caching it.
     let mut stmt = connection.prepare(&query)?;
     // Make sure we bound the parameters correctly
-    debug_assert!(worktree_roots.len() + 1 == stmt.parameter_count());
+    debug_assert!(worktree_roots.len() as i32 + 1 == stmt.parameter_count());
 
     for i in 0..worktree_roots.len() {
         let path = &worktree_roots[i].as_ref().as_os_str().as_bytes();
         // If you need to debug this, here's the string parsing:
         // let path = &worktree_roots[i].as_ref().to_string_lossy().to_string()
-        stmt.raw_bind_parameter(i + 1, path)?
+        stmt.bind_value(*path, i as i32 + 1);
     }
     // No -1, because SQLite is 1 based
-    stmt.raw_bind_parameter(worktree_roots.len() + 1, worktree_roots.len())?;
-
-    let mut rows = stmt.raw_query();
-    let row = rows.next();
-    let result = if let Ok(Some(row)) = row {
-        Ok(Some(WorkspaceId(row.get(0)?)))
-    } else {
-        Ok(None)
-    };
+    stmt.bind_value(worktree_roots.len(), worktree_roots.len() as i32 + 1)?;
 
-    // Ensure that this query only returns one row. The PRIMARY KEY constraint should catch this case
-    // but this is here to catch if someone refactors that constraint out.
-    debug_assert!(matches!(rows.next(), Ok(None)));
-
-    result
-    // }
-
-    // match logic(worktree_roots, connection) {
-    //     Ok(result) => result,
-    //     Err(err) => {
-    //         log::error!(
-    //             "Failed to get the workspace ID for paths {:?}, err: {}",
-    //             worktree_roots,
-    //             err
-    //         );
-    //         None
-    //     }
-    // }
+    stmt.maybe_row()
+        .map(|row| row.map(|id| WorkspaceId(id)))
 }
 
 #[cfg(test)]

crates/sqlez/src/connection.rs 🔗

@@ -53,6 +53,15 @@ impl Connection {
         self.persistent
     }
 
+    pub(crate) fn last_insert_id(&self) -> i64 {
+        unsafe { sqlite3_last_insert_rowid(self.sqlite3) }
+    }
+
+    pub fn insert(&self, query: impl AsRef<str>) -> Result<i64> {
+        self.exec(query)?;
+        Ok(self.last_insert_id())
+    }
+
     pub fn exec(&self, query: impl AsRef<str>) -> Result<()> {
         unsafe {
             sqlite3_exec(
@@ -140,9 +149,9 @@ mod test {
         connection
             .prepare("INSERT INTO text (text) VALUES (?);")
             .unwrap()
-            .bound(text)
+            .bind(text)
             .unwrap()
-            .run()
+            .exec()
             .unwrap();
 
         assert_eq!(
@@ -176,8 +185,8 @@ mod test {
             .prepare("INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)")
             .unwrap();
 
-        insert.bound(tuple1.clone()).unwrap().run().unwrap();
-        insert.bound(tuple2.clone()).unwrap().run().unwrap();
+        insert.bind(tuple1.clone()).unwrap().exec().unwrap();
+        insert.bind(tuple2.clone()).unwrap().exec().unwrap();
 
         assert_eq!(
             connection
@@ -203,7 +212,7 @@ mod test {
             .prepare("INSERT INTO blobs (data) VALUES (?);")
             .unwrap();
         write.bind_blob(1, blob).unwrap();
-        write.run().unwrap();
+        write.exec().unwrap();
 
         // Backup connection1 to connection2
         let connection2 = Connection::open_memory("backup_works_other");

crates/sqlez/src/migrations.rs 🔗

@@ -22,6 +22,7 @@ const MIGRATIONS_MIGRATION: Migration = Migration::new(
     "}],
 );
 
+#[derive(Debug)]
 pub struct Migration {
     domain: &'static str,
     migrations: &'static [&'static str],
@@ -46,7 +47,7 @@ impl Migration {
                 WHERE domain = ?
                 ORDER BY step
                 "})?
-            .bound(self.domain)?
+            .bind(self.domain)?
             .rows::<(String, usize, String)>()?;
 
         let mut store_completed_migration = connection
@@ -71,8 +72,8 @@ impl Migration {
 
             connection.exec(migration)?;
             store_completed_migration
-                .bound((self.domain, index, *migration))?
-                .run()?;
+                .bind((self.domain, index, *migration))?
+                .exec()?;
         }
 
         Ok(())
@@ -162,9 +163,9 @@ mod test {
                 .unwrap();
 
             store_completed_migration
-                .bound((domain, i, i.to_string()))
+                .bind((domain, i, i.to_string()))
                 .unwrap()
-                .run()
+                .exec()
                 .unwrap();
         }
     }

crates/sqlez/src/savepoint.rs 🔗

@@ -3,10 +3,36 @@ use anyhow::Result;
 use crate::connection::Connection;
 
 impl Connection {
+    // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
+    // returns Err(_), the savepoint will be rolled back. Otherwise, the save
+    // point is released.
+    pub fn with_savepoint<R, F>(&mut self, name: impl AsRef<str>, f: F) -> Result<R>
+    where
+        F: FnOnce(&mut Connection) -> Result<R>,
+    {
+        let name = name.as_ref().to_owned();
+        self.exec(format!("SAVEPOINT {}", &name))?;
+        let result = f(self);
+        match result {
+            Ok(_) => {
+                self.exec(format!("RELEASE {}", name))?;
+            }
+            Err(_) => {
+                self.exec(format!("ROLLBACK TO {}", name))?;
+                self.exec(format!("RELEASE {}", name))?;
+            }
+        }
+        result
+    }
+
     // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
     // returns Ok(None) or Err(_), the savepoint will be rolled back. Otherwise, the save
     // point is released.
-    pub fn with_savepoint<F, R>(&mut self, name: impl AsRef<str>, f: F) -> Result<Option<R>>
+    pub fn with_savepoint_rollback<R, F>(
+        &mut self,
+        name: impl AsRef<str>,
+        f: F,
+    ) -> Result<Option<R>>
     where
         F: FnOnce(&mut Connection) -> Result<Option<R>>,
     {
@@ -50,15 +76,15 @@ mod tests {
         connection.with_savepoint("first", |save1| {
             save1
                 .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
-                .bound((save1_text, 1))?
-                .run()?;
+                .bind((save1_text, 1))?
+                .exec()?;
 
             assert!(save1
                 .with_savepoint("second", |save2| -> Result<Option<()>, anyhow::Error> {
                     save2
                         .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
-                        .bound((save2_text, 2))?
-                        .run()?;
+                        .bind((save2_text, 2))?
+                        .exec()?;
 
                     assert_eq!(
                         save2
@@ -79,11 +105,34 @@ mod tests {
                 vec![save1_text],
             );
 
-            save1.with_savepoint("second", |save2| {
+            save1.with_savepoint_rollback::<(), _>("second", |save2| {
                 save2
                     .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
-                    .bound((save2_text, 2))?
-                    .run()?;
+                    .bind((save2_text, 2))?
+                    .exec()?;
+
+                assert_eq!(
+                    save2
+                        .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
+                        .rows::<String>()?,
+                    vec![save1_text, save2_text],
+                );
+
+                Ok(None)
+            })?;
+
+            assert_eq!(
+                save1
+                    .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
+                    .rows::<String>()?,
+                vec![save1_text],
+            );
+
+            save1.with_savepoint_rollback("second", |save2| {
+                save2
+                    .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
+                    .bind((save2_text, 2))?
+                    .exec()?;
 
                 assert_eq!(
                     save2
@@ -102,9 +151,16 @@ mod tests {
                 vec![save1_text, save2_text],
             );
 
-            Ok(Some(()))
+            Ok(())
         })?;
 
+        assert_eq!(
+            connection
+                .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
+                .rows::<String>()?,
+            vec![save1_text, save2_text],
+        );
+
         Ok(())
     }
 }

crates/sqlez/src/statement.rs 🔗

@@ -60,6 +60,10 @@ impl<'a> Statement<'a> {
         }
     }
 
+    pub fn parameter_count(&self) -> i32 {
+        unsafe { sqlite3_bind_parameter_count(self.raw_statement) }
+    }
+
     pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
         let index = index as c_int;
         let blob_pointer = blob.as_ptr() as *const _;
@@ -175,8 +179,9 @@ impl<'a> Statement<'a> {
         Ok(str::from_utf8(slice)?)
     }
 
-    pub fn bind<T: Bind>(&self, value: T) -> Result<()> {
-        value.bind(self, 1)?;
+    pub fn bind_value<T: Bind>(&self, value: T, idx: i32) -> Result<()> {
+        debug_assert!(idx > 0);
+        value.bind(self, idx)?;
         Ok(())
     }
 
@@ -198,8 +203,8 @@ impl<'a> Statement<'a> {
         }
     }
 
-    pub fn bound(&mut self, bindings: impl Bind) -> Result<&mut Self> {
-        self.bind(bindings)?;
+    pub fn bind(&mut self, bindings: impl Bind) -> Result<&mut Self> {
+        self.bind_value(bindings, 1)?;
         Ok(self)
     }
 
@@ -217,7 +222,12 @@ impl<'a> Statement<'a> {
         }
     }
 
-    pub fn run(&mut self) -> Result<()> {
+    pub fn insert(&mut self) -> Result<i64> {
+        self.exec()?;
+        Ok(self.connection.last_insert_id())
+    }
+
+    pub fn exec(&mut self) -> Result<()> {
         fn logic(this: &mut Statement) -> Result<()> {
             while this.step()? == StepResult::Row {}
             Ok(())

crates/sqlez/src/thread_safe_connection.rs 🔗

@@ -3,12 +3,13 @@ use std::{ops::Deref, sync::Arc};
 use connection::Connection;
 use thread_local::ThreadLocal;
 
-use crate::connection;
+use crate::{connection, migrations::Migration};
 
 pub struct ThreadSafeConnection {
     uri: Arc<str>,
     persistent: bool,
     initialize_query: Option<&'static str>,
+    migrations: Option<&'static [Migration]>,
     connection: Arc<ThreadLocal<Connection>>,
 }
 
@@ -18,6 +19,7 @@ impl ThreadSafeConnection {
             uri: Arc::from(uri),
             persistent,
             initialize_query: None,
+            migrations: None,
             connection: Default::default(),
         }
     }
@@ -29,6 +31,11 @@ impl ThreadSafeConnection {
         self
     }
 
+    pub fn with_migrations(mut self, migrations: &'static [Migration]) -> Self {
+        self.migrations = Some(migrations);
+        self
+    }
+
     /// Opens a new db connection with the initialized file path. This is internal and only
     /// called from the deref function.
     /// If opening fails, the connection falls back to a shared memory connection
@@ -49,6 +56,7 @@ impl Clone for ThreadSafeConnection {
             uri: self.uri.clone(),
             persistent: self.persistent,
             initialize_query: self.initialize_query.clone(),
+            migrations: self.migrations.clone(),
             connection: self.connection.clone(),
         }
     }
@@ -72,6 +80,14 @@ impl Deref for ThreadSafeConnection {
                 ));
             }
 
+            if let Some(migrations) = self.migrations {
+                for migration in migrations {
+                    migration
+                        .run(&connection)
+                        .expect(&format!("Migrations failed to execute: {:?}", migration));
+                }
+            }
+
             connection
         })
     }