1use anyhow::Result;
2
3use crate::connection::Connection;
4
5impl Connection {
6 // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
7 // returns Err(_), the savepoint will be rolled back. Otherwise, the save
8 // point is released.
9 pub fn with_savepoint<R, F>(&self, name: impl AsRef<str>, f: F) -> Result<R>
10 where
11 F: FnOnce(&Connection) -> Result<R>,
12 {
13 let name = name.as_ref().to_owned();
14 self.exec(format!("SAVEPOINT {}", &name))?;
15 let result = f(self);
16 match result {
17 Ok(_) => {
18 self.exec(format!("RELEASE {}", name))?;
19 }
20 Err(_) => {
21 self.exec(format!("ROLLBACK TO {}", name))?;
22 self.exec(format!("RELEASE {}", name))?;
23 }
24 }
25 result
26 }
27
28 // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
29 // returns Ok(None) or Err(_), the savepoint will be rolled back. Otherwise, the save
30 // point is released.
31 pub fn with_savepoint_rollback<R, F>(&self, name: impl AsRef<str>, f: F) -> Result<Option<R>>
32 where
33 F: FnOnce(&Connection) -> Result<Option<R>>,
34 {
35 let name = name.as_ref().to_owned();
36 self.exec(format!("SAVEPOINT {}", &name))?;
37 let result = f(self);
38 match result {
39 Ok(Some(_)) => {
40 self.exec(format!("RELEASE {}", name))?;
41 }
42 Ok(None) | Err(_) => {
43 self.exec(format!("ROLLBACK TO {}", name))?;
44 self.exec(format!("RELEASE {}", name))?;
45 }
46 }
47 result
48 }
49}
50
#[cfg(test)]
mod tests {
    use crate::connection::Connection;
    use anyhow::Result;
    use indoc::indoc;

    // Checks that nested savepoints keep or discard their inserts depending on
    // what the nested callback returns.
    #[test]
    fn test_nested_savepoints() -> Result<()> {
        const INSERT_ROW: &str = "INSERT INTO text(text, idx) VALUES (?, ?)";
        const SELECT_TEXTS: &str = "SELECT text FROM text ORDER BY text.idx ASC";

        let save1_text = "test save1";
        let save2_text = "test save2";

        let connection = Connection::open_memory("nested_savepoints");
        connection
            .exec(indoc! {"
                CREATE TABLE text (
                    text TEXT,
                    idx INTEGER
                );"})
            .unwrap();

        connection.with_savepoint("first", |outer| {
            outer
                .prepare(INSERT_ROW)?
                .with_bindings((save1_text, 1))?
                .exec()?;

            // A nested savepoint whose callback fails must report the failure
            // and leave only the outer row behind.
            let failed = outer.with_savepoint("second", |inner| -> Result<Option<()>> {
                inner
                    .prepare(INSERT_ROW)?
                    .with_bindings((save2_text, 2))?
                    .exec()?;

                let visible = inner.prepare(SELECT_TEXTS)?.rows::<String>()?;
                assert_eq!(visible, vec![save1_text, save2_text]);

                anyhow::bail!("Failed second save point :(")
            });
            assert!(failed.is_err());

            let visible = outer.prepare(SELECT_TEXTS)?.rows::<String>()?;
            assert_eq!(visible, vec![save1_text]);

            // Returning Ok(None) asks for the nested savepoint to be rolled back.
            outer.with_savepoint_rollback::<(), _>("second", |inner| {
                inner
                    .prepare(INSERT_ROW)?
                    .with_bindings((save2_text, 2))?
                    .exec()?;

                let visible = inner.prepare(SELECT_TEXTS)?.rows::<String>()?;
                assert_eq!(visible, vec![save1_text, save2_text]);

                Ok(None)
            })?;

            let visible = outer.prepare(SELECT_TEXTS)?.rows::<String>()?;
            assert_eq!(visible, vec![save1_text]);

            // Returning Ok(Some(_)) keeps the nested savepoint's insert.
            outer.with_savepoint_rollback("second", |inner| {
                inner
                    .prepare(INSERT_ROW)?
                    .with_bindings((save2_text, 2))?
                    .exec()?;

                let visible = inner.prepare(SELECT_TEXTS)?.rows::<String>()?;
                assert_eq!(visible, vec![save1_text, save2_text]);

                Ok(Some(()))
            })?;

            let visible = outer.prepare(SELECT_TEXTS)?.rows::<String>()?;
            assert_eq!(visible, vec![save1_text, save2_text]);

            Ok(())
        })?;

        // Both rows survive once the outer savepoint has been released.
        let visible = connection.prepare(SELECT_TEXTS)?.rows::<String>()?;
        assert_eq!(visible, vec![save1_text, save2_text]);

        Ok(())
    }
}