1use anyhow::Result;
2
3use crate::connection::Connection;
4
5impl Connection {
6 // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
7 // returns Err(_), the savepoint will be rolled back. Otherwise, the save
8 // point is released.
9 pub fn with_savepoint<R, F>(&mut self, name: impl AsRef<str>, f: F) -> Result<R>
10 where
11 F: FnOnce(&mut Connection) -> Result<R>,
12 {
13 let name = name.as_ref().to_owned();
14 self.exec(format!("SAVEPOINT {}", &name))?;
15 let result = f(self);
16 match result {
17 Ok(_) => {
18 self.exec(format!("RELEASE {}", name))?;
19 }
20 Err(_) => {
21 self.exec(format!("ROLLBACK TO {}", name))?;
22 self.exec(format!("RELEASE {}", name))?;
23 }
24 }
25 result
26 }
27
28 // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
29 // returns Ok(None) or Err(_), the savepoint will be rolled back. Otherwise, the save
30 // point is released.
31 pub fn with_savepoint_rollback<R, F>(
32 &mut self,
33 name: impl AsRef<str>,
34 f: F,
35 ) -> Result<Option<R>>
36 where
37 F: FnOnce(&mut Connection) -> Result<Option<R>>,
38 {
39 let name = name.as_ref().to_owned();
40 self.exec(format!("SAVEPOINT {}", &name))?;
41 let result = f(self);
42 match result {
43 Ok(Some(_)) => {
44 self.exec(format!("RELEASE {}", name))?;
45 }
46 Ok(None) | Err(_) => {
47 self.exec(format!("ROLLBACK TO {}", name))?;
48 self.exec(format!("RELEASE {}", name))?;
49 }
50 }
51 result
52 }
53}
54
#[cfg(test)]
mod tests {
    use crate::connection::Connection;
    use anyhow::Result;
    use indoc::indoc;

    /// Returns every value of the `text` column, ordered by `idx`.
    fn texts(connection: &mut Connection) -> Result<Vec<String>> {
        Ok(connection
            .prepare("SELECT text FROM text ORDER BY text.idx ASC")?
            .rows::<String>()?)
    }

    /// Exercises nested savepoints.
    ///
    /// Covers three cases: a failing inner `with_savepoint`, an inner
    /// `with_savepoint_rollback` returning `Ok(None)` (rolled back), and one returning
    /// `Ok(Some(_))` (kept). After each case it checks which rows are visible.
    #[test]
    fn test_nested_savepoints() -> Result<()> {
        let mut connection = Connection::open_memory("nested_savepoints");

        connection.exec(indoc! {"
            CREATE TABLE text (
                text TEXT,
                idx INTEGER
            );"})?;

        let save1_text = "test save1";
        let save2_text = "test save2";

        connection.with_savepoint("first", |save1| {
            save1
                .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
                .with_bindings((save1_text, 1))?
                .exec()?;

            // A failing callback rolls its savepoint back and surfaces its own error.
            let err = save1
                .with_savepoint("second", |save2| -> Result<()> {
                    save2
                        .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
                        .with_bindings((save2_text, 2))?
                        .exec()?;

                    assert_eq!(texts(save2)?, vec![save1_text, save2_text]);

                    anyhow::bail!("Failed second save point :(")
                })
                .expect_err("a failing callback must make with_savepoint return an error");
            assert_eq!(err.to_string(), "Failed second save point :(");
            assert_eq!(texts(save1)?, vec![save1_text]);

            // `Ok(None)` rolls the savepoint back without reporting an error.
            save1.with_savepoint_rollback::<(), _>("second", |save2| {
                save2
                    .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
                    .with_bindings((save2_text, 2))?
                    .exec()?;

                assert_eq!(texts(save2)?, vec![save1_text, save2_text]);

                Ok(None)
            })?;
            assert_eq!(texts(save1)?, vec![save1_text]);

            // `Ok(Some(_))` releases the savepoint and keeps its changes.
            save1.with_savepoint_rollback("second", |save2| {
                save2
                    .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")?
                    .with_bindings((save2_text, 2))?
                    .exec()?;

                assert_eq!(texts(save2)?, vec![save1_text, save2_text]);

                Ok(Some(()))
            })?;
            assert_eq!(texts(save1)?, vec![save1_text, save2_text]);

            Ok(())
        })?;

        // Releasing the outer savepoint keeps everything committed inside it.
        assert_eq!(texts(&mut connection)?, vec![save1_text, save2_text]);

        Ok(())
    }
}