1use anyhow::Result;
2use indoc::formatdoc;
3
4use crate::connection::Connection;
5
6impl Connection {
7 // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
8 // returns Err(_), the savepoint will be rolled back. Otherwise, the save
9 // point is released.
10 pub fn with_savepoint<R, F>(&self, name: impl AsRef<str>, f: F) -> Result<R>
11 where
12 F: FnOnce() -> Result<R>,
13 {
14 let name = name.as_ref();
15 self.exec(&format!("SAVEPOINT {name}"))?()?;
16 let result = f();
17 match result {
18 Ok(_) => {
19 self.exec(&format!("RELEASE {name}"))?()?;
20 }
21 Err(_) => {
22 self.exec(&formatdoc! {"
23 ROLLBACK TO {name};
24 RELEASE {name}"})?()?;
25 }
26 }
27 result
28 }
29
30 // Run a set of commands within the context of a `SAVEPOINT name`. If the callback
31 // returns Ok(None) or Err(_), the savepoint will be rolled back. Otherwise, the save
32 // point is released.
33 pub fn with_savepoint_rollback<R, F>(&self, name: impl AsRef<str>, f: F) -> Result<Option<R>>
34 where
35 F: FnOnce() -> Result<Option<R>>,
36 {
37 let name = name.as_ref();
38 self.exec(&format!("SAVEPOINT {name}"))?()?;
39 let result = f();
40 match result {
41 Ok(Some(_)) => {
42 self.exec(&format!("RELEASE {name}"))?()?;
43 }
44 Ok(None) | Err(_) => {
45 self.exec(&formatdoc! {"
46 ROLLBACK TO {name};
47 RELEASE {name}"})?()?;
48 }
49 }
50 result
51 }
52}
53
#[cfg(test)]
mod tests {
    use crate::connection::Connection;
    use anyhow::Result;
    use indoc::indoc;

    /// Reads every row of the test table in insertion (`idx`) order.
    const SELECT_TEXTS: &str = "SELECT text FROM text ORDER BY text.idx ASC";
    /// Inserts one `(text, idx)` row into the test table.
    const INSERT_TEXT: &str = "INSERT INTO text(text, idx) VALUES (?, ?)";

    /// Nests `with_savepoint` / `with_savepoint_rollback` inside an outer savepoint
    /// and checks, after each step, which rows are visible.
    #[test]
    fn test_nested_savepoints() -> Result<()> {
        let connection = Connection::open_memory(Some("nested_savepoints"));

        connection.exec(indoc! {"
            CREATE TABLE text (
                text TEXT,
                idx INTEGER
            );"})?()?;

        let save1_text = "test save1";
        let save2_text = "test save2";

        connection.with_savepoint("first", || {
            connection.exec_bound(INSERT_TEXT)?((save1_text, 1))?;

            // An error from the inner callback rolls back only the inner savepoint.
            let failed = connection.with_savepoint("second", || -> Result<()> {
                connection.exec_bound(INSERT_TEXT)?((save2_text, 2))?;

                assert_eq!(
                    connection.select::<String>(SELECT_TEXTS)?()?,
                    vec![save1_text, save2_text],
                );

                anyhow::bail!("Failed second save point :(")
            });
            assert!(failed.is_err());
            assert_eq!(
                connection.select::<String>(SELECT_TEXTS)?()?,
                vec![save1_text],
            );

            // Ok(None) rolls back without being treated as an error.
            let kept = connection.with_savepoint_rollback::<(), _>("second", || {
                connection.exec_bound(INSERT_TEXT)?((save2_text, 2))?;

                assert_eq!(
                    connection.select::<String>(SELECT_TEXTS)?()?,
                    vec![save1_text, save2_text],
                );

                Ok(None)
            })?;
            assert!(kept.is_none());
            assert_eq!(
                connection.select::<String>(SELECT_TEXTS)?()?,
                vec![save1_text],
            );

            // Err(_) also rolls back, and the error is propagated to the caller.
            let failed = connection.with_savepoint_rollback::<(), _>("second", || {
                connection.exec_bound(INSERT_TEXT)?((save2_text, 2))?;
                anyhow::bail!("Failed rollback save point :(")
            });
            assert!(failed.is_err());
            assert_eq!(
                connection.select::<String>(SELECT_TEXTS)?()?,
                vec![save1_text],
            );

            // Ok(Some(_)) releases the savepoint and keeps its changes.
            let kept = connection.with_savepoint_rollback("second", || {
                connection.exec_bound(INSERT_TEXT)?((save2_text, 2))?;

                assert_eq!(
                    connection.select::<String>(SELECT_TEXTS)?()?,
                    vec![save1_text, save2_text],
                );

                Ok(Some(()))
            })?;
            assert_eq!(kept, Some(()));
            assert_eq!(
                connection.select::<String>(SELECT_TEXTS)?()?,
                vec![save1_text, save2_text],
            );

            Ok(())
        })?;

        // Releasing the outer savepoint keeps everything written inside it.
        assert_eq!(
            connection.select::<String>(SELECT_TEXTS)?()?,
            vec![save1_text, save2_text],
        );

        Ok(())
    }
}