savepoint.rs

  1use anyhow::Result;
  2
  3use crate::connection::Connection;
  4
  5impl Connection {
  6    // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
  7    // returns Err(_), the savepoint will be rolled back. Otherwise, the save
  8    // point is released.
  9    pub fn with_savepoint<R, F>(&self, name: impl AsRef<str>, f: F) -> Result<R>
 10    where
 11        F: FnOnce(&Connection) -> Result<R>,
 12    {
 13        let name = name.as_ref().to_owned();
 14        self.exec(format!("SAVEPOINT {}", &name))?;
 15        let result = f(self);
 16        match result {
 17            Ok(_) => {
 18                self.exec(format!("RELEASE {}", name))?;
 19            }
 20            Err(_) => {
 21                self.exec(format!("ROLLBACK TO {}", name))?;
 22                self.exec(format!("RELEASE {}", name))?;
 23            }
 24        }
 25        result
 26    }
 27
 28    // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
 29    // returns Ok(None) or Err(_), the savepoint will be rolled back. Otherwise, the save
 30    // point is released.
 31    pub fn with_savepoint_rollback<R, F>(&self, name: impl AsRef<str>, f: F) -> Result<Option<R>>
 32    where
 33        F: FnOnce(&Connection) -> Result<Option<R>>,
 34    {
 35        let name = name.as_ref().to_owned();
 36        self.exec(format!("SAVEPOINT {}", &name))?;
 37        let result = f(self);
 38        match result {
 39            Ok(Some(_)) => {
 40                self.exec(format!("RELEASE {}", name))?;
 41            }
 42            Ok(None) | Err(_) => {
 43                self.exec(format!("ROLLBACK TO {}", name))?;
 44                self.exec(format!("RELEASE {}", name))?;
 45            }
 46        }
 47        result
 48    }
 49}
 50
 51#[cfg(test)]
 52mod tests {
 53    use crate::connection::Connection;
 54    use anyhow::Result;
 55    use indoc::indoc;
 56
 57    #[test]
 58    fn test_nested_savepoints() -> Result<()> {
 59        let connection = Connection::open_memory("nested_savepoints");
 60
 61        connection
 62            .exec(indoc! {"
 63            CREATE TABLE text (
 64                text TEXT,
 65                idx INTEGER
 66            );"})
 67            .unwrap();
 68
 69        let save1_text = "test save1";
 70        let save2_text = "test save2";
 71
 72        connection.with_savepoint("first", |save1| {
 73            save1
 74                .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
 75                .with_bindings((save1_text, 1))?
 76                .exec()?;
 77
 78            assert!(save1
 79                .with_savepoint("second", |save2| -> Result<Option<()>, anyhow::Error> {
 80                    save2
 81                        .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
 82                        .with_bindings((save2_text, 2))?
 83                        .exec()?;
 84
 85                    assert_eq!(
 86                        save2
 87                            .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
 88                            .rows::<String>()?,
 89                        vec![save1_text, save2_text],
 90                    );
 91
 92                    anyhow::bail!("Failed second save point :(")
 93                })
 94                .err()
 95                .is_some());
 96
 97            assert_eq!(
 98                save1
 99                    .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
100                    .rows::<String>()?,
101                vec![save1_text],
102            );
103
104            save1.with_savepoint_rollback::<(), _>("second", |save2| {
105                save2
106                    .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
107                    .with_bindings((save2_text, 2))?
108                    .exec()?;
109
110                assert_eq!(
111                    save2
112                        .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
113                        .rows::<String>()?,
114                    vec![save1_text, save2_text],
115                );
116
117                Ok(None)
118            })?;
119
120            assert_eq!(
121                save1
122                    .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
123                    .rows::<String>()?,
124                vec![save1_text],
125            );
126
127            save1.with_savepoint_rollback("second", |save2| {
128                save2
129                    .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
130                    .with_bindings((save2_text, 2))?
131                    .exec()?;
132
133                assert_eq!(
134                    save2
135                        .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
136                        .rows::<String>()?,
137                    vec![save1_text, save2_text],
138                );
139
140                Ok(Some(()))
141            })?;
142
143            assert_eq!(
144                save1
145                    .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
146                    .rows::<String>()?,
147                vec![save1_text, save2_text],
148            );
149
150            Ok(())
151        })?;
152
153        assert_eq!(
154            connection
155                .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
156                .rows::<String>()?,
157            vec![save1_text, save2_text],
158        );
159
160        Ok(())
161    }
162}