1use anyhow::Result;
2use indoc::formatdoc;
3
4use crate::connection::Connection;
5
6impl Connection {
7 // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
8 // returns Err(_), the savepoint will be rolled back. Otherwise, the save
9 // point is released.
10 pub fn with_savepoint<R, F>(&self, name: impl AsRef<str>, f: F) -> Result<R>
11 where
12 F: FnOnce() -> Result<R>,
13 {
14 let name = name.as_ref();
15 self.exec(&format!("SAVEPOINT {name}"))?()?;
16 let result = f();
17 match result {
18 Ok(_) => {
19 self.exec(&format!("RELEASE {name}"))?()?;
20 }
21 Err(_) => {
22 self.exec(&formatdoc! {"
23 ROLLBACK TO {name};
24 RELEASE {name}"})?()?;
25 }
26 }
27 result
28 }
29
30 // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
31 // returns Ok(None) or Err(_), the savepoint will be rolled back. Otherwise, the save
32 // point is released.
33 pub fn with_savepoint_rollback<R, F>(&self, name: impl AsRef<str>, f: F) -> Result<Option<R>>
34 where
35 F: FnOnce() -> Result<Option<R>>,
36 {
37 let name = name.as_ref();
38 self.exec(&format!("SAVEPOINT {name}"))?()?;
39 let result = f();
40 match result {
41 Ok(Some(_)) => {
42 self.exec(&format!("RELEASE {name}"))?()?;
43 }
44 Ok(None) | Err(_) => {
45 self.exec(&formatdoc! {"
46 ROLLBACK TO {name};
47 RELEASE {name}"})?()?;
48 }
49 }
50 result
51 }
52}
53
#[cfg(test)]
mod tests {
    use crate::connection::Connection;
    use anyhow::Result;
    use indoc::indoc;

    const INSERT_TEXT: &str = "INSERT INTO text(text, idx) VALUES (?, ?)";
    const SELECT_TEXTS: &str = "SELECT text FROM text ORDER BY text.idx ASC";

    /// Exercises nested savepoints.
    ///
    /// An inner savepoint that fails (callback `Err`, or `Ok(None)` / `Err` from
    /// `with_savepoint_rollback`) must discard only its own writes. An inner savepoint that
    /// succeeds keeps its writes, and they survive the release of the outer savepoint.
    #[test]
    fn test_nested_savepoints() -> Result<()> {
        let connection = Connection::open_memory(Some("nested_savepoints"));

        connection.exec(indoc! {"
            CREATE TABLE text (
                text TEXT,
                idx INTEGER
            );"})?()?;

        // Reads every row's `text` column, ordered by `idx`.
        let texts = || -> Result<Vec<String>> { Ok(connection.select::<String>(SELECT_TEXTS)?()?) };

        let save1_text = "test save1";
        let save2_text = "test save2";

        connection.with_savepoint("first", || {
            connection.exec_bound(INSERT_TEXT)?((save1_text, 1))?;

            // The callback returns Err: the inner insert must be rolled back.
            let failed = connection.with_savepoint("second", || -> Result<()> {
                connection.exec_bound(INSERT_TEXT)?((save2_text, 2))?;
                assert_eq!(texts()?, vec![save1_text, save2_text]);
                anyhow::bail!("Failed second save point :(")
            });
            assert!(failed.is_err());
            assert_eq!(texts()?, vec![save1_text]);

            // The callback returns Ok(None): the inner insert must be rolled back.
            connection.with_savepoint_rollback::<(), _>("second", || {
                connection.exec_bound(INSERT_TEXT)?((save2_text, 2))?;
                assert_eq!(texts()?, vec![save1_text, save2_text]);
                Ok(None)
            })?;
            assert_eq!(texts()?, vec![save1_text]);

            // The callback returns Ok(Some(_)): the inner insert must be kept.
            connection.with_savepoint_rollback("second", || {
                connection.exec_bound(INSERT_TEXT)?((save2_text, 2))?;
                assert_eq!(texts()?, vec![save1_text, save2_text]);
                Ok(Some(()))
            })?;
            assert_eq!(texts()?, vec![save1_text, save2_text]);

            // The callback returns Err from `with_savepoint_rollback`: only this insert is
            // rolled back.
            let failed = connection.with_savepoint_rollback::<(), _>("third", || {
                connection.exec_bound(INSERT_TEXT)?((save2_text, 3))?;
                anyhow::bail!("Failed third save point :(")
            });
            assert!(failed.is_err());
            assert_eq!(texts()?, vec![save1_text, save2_text]);

            Ok(())
        })?;

        // Releasing the outer savepoint keeps everything committed inside it.
        assert_eq!(texts()?, vec![save1_text, save2_text]);

        Ok(())
    }
}