savepoint.rs

  1use anyhow::Result;
  2use indoc::formatdoc;
  3
  4use crate::connection::Connection;
  5
  6impl Connection {
  7    // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
  8    // returns Err(_), the savepoint will be rolled back. Otherwise, the save
  9    // point is released.
 10    pub fn with_savepoint<R, F>(&self, name: impl AsRef<str>, f: F) -> Result<R>
 11    where
 12        F: FnOnce() -> Result<R>,
 13    {
 14        let name = name.as_ref();
 15        self.exec(&format!("SAVEPOINT {name}"))?()?;
 16        let result = f();
 17        match result {
 18            Ok(_) => {
 19                self.exec(&format!("RELEASE {name}"))?()?;
 20            }
 21            Err(_) => {
 22                self.exec(&formatdoc! {"
 23                    ROLLBACK TO {name};
 24                    RELEASE {name}"})?()?;
 25            }
 26        }
 27        result
 28    }
 29
 30    // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
 31    // returns Ok(None) or Err(_), the savepoint will be rolled back. Otherwise, the save
 32    // point is released.
 33    pub fn with_savepoint_rollback<R, F>(&self, name: impl AsRef<str>, f: F) -> Result<Option<R>>
 34    where
 35        F: FnOnce() -> Result<Option<R>>,
 36    {
 37        let name = name.as_ref();
 38        self.exec(&format!("SAVEPOINT {name}"))?()?;
 39        let result = f();
 40        match result {
 41            Ok(Some(_)) => {
 42                self.exec(&format!("RELEASE {name}"))?()?;
 43            }
 44            Ok(None) | Err(_) => {
 45                self.exec(&formatdoc! {"
 46                    ROLLBACK TO {name};
 47                    RELEASE {name}"})?()?;
 48            }
 49        }
 50        result
 51    }
 52}
 53
 54#[cfg(test)]
 55mod tests {
 56    use crate::connection::Connection;
 57    use anyhow::Result;
 58    use indoc::indoc;
 59
 60    #[test]
 61    fn test_nested_savepoints() -> Result<()> {
 62        let connection = Connection::open_memory(Some("nested_savepoints"));
 63
 64        connection
 65            .exec(indoc! {"
 66            CREATE TABLE text (
 67                text TEXT,
 68                idx INTEGER
 69            );"})
 70            .unwrap()()
 71        .unwrap();
 72
 73        let save1_text = "test save1";
 74        let save2_text = "test save2";
 75
 76        connection.with_savepoint("first", || {
 77            connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?((save1_text, 1))?;
 78
 79            assert!(
 80                connection
 81                    .with_savepoint("second", || -> anyhow::Result<Option<()>> {
 82                        connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?((
 83                            save2_text, 2,
 84                        ))?;
 85
 86                        assert_eq!(
 87                            connection
 88                                .select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?(
 89                            )?,
 90                            vec![save1_text, save2_text],
 91                        );
 92
 93                        anyhow::bail!("Failed second save point :(")
 94                    })
 95                    .err()
 96                    .is_some()
 97            );
 98
 99            assert_eq!(
100                connection.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?()?,
101                vec![save1_text],
102            );
103
104            connection.with_savepoint_rollback::<(), _>("second", || {
105                connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?((
106                    save2_text, 2,
107                ))?;
108
109                assert_eq!(
110                    connection.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?()?,
111                    vec![save1_text, save2_text],
112                );
113
114                Ok(None)
115            })?;
116
117            assert_eq!(
118                connection.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?()?,
119                vec![save1_text],
120            );
121
122            connection.with_savepoint_rollback("second", || {
123                connection.exec_bound("INSERT INTO text(text, idx) VALUES (?, ?)")?((
124                    save2_text, 2,
125                ))?;
126
127                assert_eq!(
128                    connection.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?()?,
129                    vec![save1_text, save2_text],
130                );
131
132                Ok(Some(()))
133            })?;
134
135            assert_eq!(
136                connection.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?()?,
137                vec![save1_text, save2_text],
138            );
139
140            Ok(())
141        })?;
142
143        assert_eq!(
144            connection.select::<String>("SELECT text FROM text ORDER BY text.idx ASC")?()?,
145            vec![save1_text, save2_text],
146        );
147
148        Ok(())
149    }
150}