savepoint.rs

  1use anyhow::Result;
  2use indoc::formatdoc;
  3
  4use crate::connection::Connection;
  5
  6impl Connection {
  7    // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
  8    // returns Err(_), the savepoint will be rolled back. Otherwise, the save
  9    // point is released.
 10    pub fn with_savepoint<R, F>(&self, name: impl AsRef<str>, f: F) -> Result<R>
 11    where
 12        F: FnOnce() -> Result<R>,
 13    {
 14        let name = name.as_ref();
 15        self.exec(&format!("SAVEPOINT {name}"))?()?;
 16        let result = f();
 17        match result {
 18            Ok(_) => {
 19                self.exec(&format!("RELEASE {name}"))?()?;
 20            }
 21            Err(_) => {
 22                self.exec(&formatdoc! {"
 23                    ROLLBACK TO {name};
 24                    RELEASE {name}"})?()?;
 25            }
 26        }
 27        result
 28    }
 29
 30    // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
 31    // returns Ok(None) or Err(_), the savepoint will be rolled back. Otherwise, the save
 32    // point is released.
 33    pub fn with_savepoint_rollback<R, F>(&self, name: impl AsRef<str>, f: F) -> Result<Option<R>>
 34    where
 35        F: FnOnce() -> Result<Option<R>>,
 36    {
 37        let name = name.as_ref();
 38        self.exec(&format!("SAVEPOINT {name}"))?()?;
 39        let result = f();
 40        match result {
 41            Ok(Some(_)) => {
 42                self.exec(&format!("RELEASE {name}"))?()?;
 43            }
 44            Ok(None) | Err(_) => {
 45                self.exec(&formatdoc! {"
 46                    ROLLBACK TO {name};
 47                    RELEASE {name}"})?()?;
 48            }
 49        }
 50        result
 51    }
 52}
 53
 54#[cfg(test)]
 55mod tests {
 56    use crate::connection::Connection;
 57    use anyhow::Result;
 58    use indoc::indoc;
 59
 60    #[test]
 61    fn test_nested_savepoints() -> Result<()> {
 62        let connection = Connection::open_memory(Some("nested_savepoints"));
 63
 64        connection
 65            .exec(indoc! {"
 66            CREATE TABLE text (
 67                text TEXT,
 68                idx INTEGER
 69            );"})
 70            .unwrap()()
 71        .unwrap();
 72
 73        let save1_text = "test save1";
 74        let save2_text = "test save2";
 75
 76        connection.with_savepoint("first", || {
 77            connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?((save1_text, 1))?;
 78
 79            assert!(connection
 80                .with_savepoint("second", || -> Result<Option<()>, anyhow::Error> {
 81                    connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?((
 82                        save2_text, 2,
 83                    ))?;
 84
 85                    assert_eq!(
 86                        connection
 87                            .select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?(
 88                        )?,
 89                        vec![save1_text, save2_text],
 90                    );
 91
 92                    anyhow::bail!("Failed second save point :(")
 93                })
 94                .err()
 95                .is_some());
 96
 97            assert_eq!(
 98                connection.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?()?,
 99                vec![save1_text],
100            );
101
102            connection.with_savepoint_rollback::<(), _>("second", || {
103                connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?((
104                    save2_text, 2,
105                ))?;
106
107                assert_eq!(
108                    connection.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?()?,
109                    vec![save1_text, save2_text],
110                );
111
112                Ok(None)
113            })?;
114
115            assert_eq!(
116                connection.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?()?,
117                vec![save1_text],
118            );
119
120            connection.with_savepoint_rollback("second", || {
121                connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?((
122                    save2_text, 2,
123                ))?;
124
125                assert_eq!(
126                    connection.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?()?,
127                    vec![save1_text, save2_text],
128                );
129
130                Ok(Some(()))
131            })?;
132
133            assert_eq!(
134                connection.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?()?,
135                vec![save1_text, save2_text],
136            );
137
138            Ok(())
139        })?;
140
141        assert_eq!(
142            connection.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?()?,
143            vec![save1_text, save2_text],
144        );
145
146        Ok(())
147    }
148}