1use std::{
2 ffi::{CStr, CString},
3 marker::PhantomData,
4 path::Path,
5};
6
7use anyhow::{anyhow, Result};
8use libsqlite3_sys::*;
9
10pub struct Connection {
11 pub(crate) sqlite3: *mut sqlite3,
12 persistent: bool,
13 pub(crate) write: bool,
14 _sqlite: PhantomData<sqlite3>,
15}
16unsafe impl Send for Connection {}
17
18impl Connection {
19 pub(crate) fn open(uri: &str, persistent: bool) -> Result<Self> {
20 let mut connection = Self {
21 sqlite3: 0 as *mut _,
22 persistent,
23 write: true,
24 _sqlite: PhantomData,
25 };
26
27 let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
28 unsafe {
29 sqlite3_open_v2(
30 CString::new(uri)?.as_ptr(),
31 &mut connection.sqlite3,
32 flags,
33 0 as *const _,
34 );
35
36 // Turn on extended error codes
37 sqlite3_extended_result_codes(connection.sqlite3, 1);
38
39 connection.last_error()?;
40 }
41
42 Ok(connection)
43 }
44
45 /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
46 /// instead.
47 pub fn open_file(uri: &str) -> Self {
48 Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(Some(uri)))
49 }
50
51 pub fn open_memory(uri: Option<&str>) -> Self {
52 let in_memory_path = if let Some(uri) = uri {
53 format!("file:{}?mode=memory&cache=shared", uri)
54 } else {
55 ":memory:".to_string()
56 };
57
58 Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
59 }
60
61 pub fn persistent(&self) -> bool {
62 self.persistent
63 }
64
65 pub fn can_write(&self) -> bool {
66 self.write
67 }
68
69 pub fn backup_main(&self, destination: &Connection) -> Result<()> {
70 unsafe {
71 let backup = sqlite3_backup_init(
72 destination.sqlite3,
73 CString::new("main")?.as_ptr(),
74 self.sqlite3,
75 CString::new("main")?.as_ptr(),
76 );
77 sqlite3_backup_step(backup, -1);
78 sqlite3_backup_finish(backup);
79 destination.last_error()
80 }
81 }
82
83 pub fn backup_main_to(&self, destination: impl AsRef<Path>) -> Result<()> {
84 let destination = Self::open_file(destination.as_ref().to_string_lossy().as_ref());
85 self.backup_main(&destination)
86 }
87
88 pub(crate) fn last_error(&self) -> Result<()> {
89 unsafe {
90 let code = sqlite3_errcode(self.sqlite3);
91 const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
92 if NON_ERROR_CODES.contains(&code) {
93 return Ok(());
94 }
95
96 let message = sqlite3_errmsg(self.sqlite3);
97 let message = if message.is_null() {
98 None
99 } else {
100 Some(
101 String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
102 .into_owned(),
103 )
104 };
105
106 Err(anyhow!(
107 "Sqlite call failed with code {} and message: {:?}",
108 code as isize,
109 message
110 ))
111 }
112 }
113}
114
115impl Drop for Connection {
116 fn drop(&mut self) {
117 unsafe { sqlite3_close(self.sqlite3) };
118 }
119}
120
#[cfg(test)]
mod test {
    use anyhow::Result;
    use indoc::indoc;

    use crate::connection::Connection;

    #[test]
    fn string_round_trips() -> Result<()> {
        let connection = Connection::open_memory(Some("string_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE text (
                    text TEXT
                );"})
            .unwrap()()
        .unwrap();

        let text = "Some test text";

        connection
            .exec_bound("INSERT INTO text (text) VALUES (?);")
            .unwrap()(text)
        .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM text;").unwrap()().unwrap(),
            Some(text.to_string())
        );

        Ok(())
    }

    #[test]
    fn tuple_round_trips() {
        let connection = Connection::open_memory(Some("tuple_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    text TEXT,
                    integer INTEGER,
                    blob BLOB
                );"})
            .unwrap()()
        .unwrap();

        let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
        let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);

        let mut insert = connection
            .exec_bound::<(String, usize, Vec<u8>)>(
                "INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)",
            )
            .unwrap();

        insert(tuple1.clone()).unwrap();
        insert(tuple2.clone()).unwrap();

        assert_eq!(
            connection
                .select::<(String, usize, Vec<u8>)>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            vec![tuple1, tuple2]
        );
    }

    #[test]
    fn bool_round_trips() {
        let connection = Connection::open_memory(Some("bool_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE bools (
                    t INTEGER,
                    f INTEGER
                );"})
            .unwrap()()
        .unwrap();

        connection
            .exec_bound("INSERT INTO bools(t, f) VALUES (?, ?)")
            .unwrap()((true, false))
        .unwrap();

        assert_eq!(
            connection
                .select_row::<(bool, bool)>("SELECT * FROM bools;")
                .unwrap()()
            .unwrap(),
            Some((true, false))
        );
    }

    #[test]
    fn backup_works() {
        let connection1 = Connection::open_memory(Some("backup_works"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap()()
        .unwrap();
        let blob = vec![0, 1, 2, 4, 8, 16, 32, 64];
        connection1
            .exec_bound::<Vec<u8>>("INSERT INTO blobs (data) VALUES (?);")
            .unwrap()(blob.clone())
        .unwrap();

        // Backup connection1 to connection2
        let connection2 = Connection::open_memory(Some("backup_works_other"));
        connection1.backup_main(&connection2).unwrap();

        // Read from the backup destination to verify the blob was copied across
        let read_blobs = connection2
            .select::<Vec<u8>>("SELECT * FROM blobs;")
            .unwrap()()
        .unwrap();
        assert_eq!(read_blobs, vec![blob]);
    }

    #[test]
    fn multi_step_statement_works() {
        let connection = Connection::open_memory(Some("multi_step_statement_works"));

        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    col INTEGER
                )"})
            .unwrap()()
        .unwrap();

        connection
            .exec(indoc! {"
                INSERT INTO test(col) VALUES (2)"})
            .unwrap()()
        .unwrap();

        assert_eq!(
            connection
                .select_row::<usize>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            Some(2)
        );
    }
}