savepoint.rs

  1use anyhow::Result;
  2
  3use crate::connection::Connection;
  4
  5impl Connection {
  6    // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
  7    // returns Err(_), the savepoint will be rolled back. Otherwise, the save
  8    // point is released.
  9    pub fn with_savepoint<R, F>(&mut self, name: impl AsRef<str>, f: F) -> Result<R>
 10    where
 11        F: FnOnce(&mut Connection) -> Result<R>,
 12    {
 13        let name = name.as_ref().to_owned();
 14        self.exec(format!("SAVEPOINT {}", &name))?;
 15        let result = f(self);
 16        match result {
 17            Ok(_) => {
 18                self.exec(format!("RELEASE {}", name))?;
 19            }
 20            Err(_) => {
 21                self.exec(format!("ROLLBACK TO {}", name))?;
 22                self.exec(format!("RELEASE {}", name))?;
 23            }
 24        }
 25        result
 26    }
 27
 28    // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
 29    // returns Ok(None) or Err(_), the savepoint will be rolled back. Otherwise, the save
 30    // point is released.
 31    pub fn with_savepoint_rollback<R, F>(
 32        &mut self,
 33        name: impl AsRef<str>,
 34        f: F,
 35    ) -> Result<Option<R>>
 36    where
 37        F: FnOnce(&mut Connection) -> Result<Option<R>>,
 38    {
 39        let name = name.as_ref().to_owned();
 40        self.exec(format!("SAVEPOINT {}", &name))?;
 41        let result = f(self);
 42        match result {
 43            Ok(Some(_)) => {
 44                self.exec(format!("RELEASE {}", name))?;
 45            }
 46            Ok(None) | Err(_) => {
 47                self.exec(format!("ROLLBACK TO {}", name))?;
 48                self.exec(format!("RELEASE {}", name))?;
 49            }
 50        }
 51        result
 52    }
 53}
 54
 55#[cfg(test)]
 56mod tests {
 57    use crate::connection::Connection;
 58    use anyhow::Result;
 59    use indoc::indoc;
 60
 61    #[test]
 62    fn test_nested_savepoints() -> Result<()> {
 63        let mut connection = Connection::open_memory("nested_savepoints");
 64
 65        connection
 66            .exec(indoc! {"
 67            CREATE TABLE text (
 68                text TEXT,
 69                idx INTEGER
 70            );"})
 71            .unwrap();
 72
 73        let save1_text = "test save1";
 74        let save2_text = "test save2";
 75
 76        connection.with_savepoint("first", |save1| {
 77            save1
 78                .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
 79                .with_bindings((save1_text, 1))?
 80                .exec()?;
 81
 82            assert!(save1
 83                .with_savepoint("second", |save2| -> Result<Option<()>, anyhow::Error> {
 84                    save2
 85                        .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
 86                        .with_bindings((save2_text, 2))?
 87                        .exec()?;
 88
 89                    assert_eq!(
 90                        save2
 91                            .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
 92                            .rows::<String>()?,
 93                        vec![save1_text, save2_text],
 94                    );
 95
 96                    anyhow::bail!("Failed second save point :(")
 97                })
 98                .err()
 99                .is_some());
100
101            assert_eq!(
102                save1
103                    .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
104                    .rows::<String>()?,
105                vec![save1_text],
106            );
107
108            save1.with_savepoint_rollback::<(), _>("second", |save2| {
109                save2
110                    .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
111                    .with_bindings((save2_text, 2))?
112                    .exec()?;
113
114                assert_eq!(
115                    save2
116                        .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
117                        .rows::<String>()?,
118                    vec![save1_text, save2_text],
119                );
120
121                Ok(None)
122            })?;
123
124            assert_eq!(
125                save1
126                    .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
127                    .rows::<String>()?,
128                vec![save1_text],
129            );
130
131            save1.with_savepoint_rollback("second", |save2| {
132                save2
133                    .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
134                    .with_bindings((save2_text, 2))?
135                    .exec()?;
136
137                assert_eq!(
138                    save2
139                        .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
140                        .rows::<String>()?,
141                    vec![save1_text, save2_text],
142                );
143
144                Ok(Some(()))
145            })?;
146
147            assert_eq!(
148                save1
149                    .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
150                    .rows::<String>()?,
151                vec![save1_text, save2_text],
152            );
153
154            Ok(())
155        })?;
156
157        assert_eq!(
158            connection
159                .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
160                .rows::<String>()?,
161            vec![save1_text, save2_text],
162        );
163
164        Ok(())
165    }
166}