1use std::{
2 ffi::{CStr, CString},
3 marker::PhantomData,
4};
5
6use anyhow::{anyhow, Result};
7use libsqlite3_sys::*;
8
9pub struct Connection {
10 pub(crate) sqlite3: *mut sqlite3,
11 persistent: bool,
12 phantom: PhantomData<sqlite3>,
13}
14unsafe impl Send for Connection {}
15
16impl Connection {
17 fn open(uri: &str, persistent: bool) -> Result<Self> {
18 let mut connection = Self {
19 sqlite3: 0 as *mut _,
20 persistent,
21 phantom: PhantomData,
22 };
23
24 let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
25 unsafe {
26 sqlite3_open_v2(
27 CString::new(uri)?.as_ptr(),
28 &mut connection.sqlite3,
29 flags,
30 0 as *const _,
31 );
32
33 // Turn on extended error codes
34 sqlite3_extended_result_codes(connection.sqlite3, 1);
35
36 connection.last_error()?;
37 }
38
39 Ok(connection)
40 }
41
42 /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
43 /// instead.
44 pub fn open_file(uri: &str) -> Self {
45 Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(Some(uri)))
46 }
47
48 pub fn open_memory(uri: Option<&str>) -> Self {
49 let in_memory_path = if let Some(uri) = uri {
50 format!("file:{}?mode=memory&cache=shared", uri)
51 } else {
52 ":memory:".to_string()
53 };
54
55 Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
56 }
57
58 pub fn persistent(&self) -> bool {
59 self.persistent
60 }
61
62 pub fn backup_main(&self, destination: &Connection) -> Result<()> {
63 unsafe {
64 let backup = sqlite3_backup_init(
65 destination.sqlite3,
66 CString::new("main")?.as_ptr(),
67 self.sqlite3,
68 CString::new("main")?.as_ptr(),
69 );
70 sqlite3_backup_step(backup, -1);
71 sqlite3_backup_finish(backup);
72 destination.last_error()
73 }
74 }
75
76 pub(crate) fn last_error(&self) -> Result<()> {
77 unsafe {
78 let code = sqlite3_errcode(self.sqlite3);
79 const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
80 if NON_ERROR_CODES.contains(&code) {
81 return Ok(());
82 }
83
84 let message = sqlite3_errmsg(self.sqlite3);
85 let message = if message.is_null() {
86 None
87 } else {
88 Some(
89 String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
90 .into_owned(),
91 )
92 };
93
94 Err(anyhow!(
95 "Sqlite call failed with code {} and message: {:?}",
96 code as isize,
97 message
98 ))
99 }
100 }
101}
102
103impl Drop for Connection {
104 fn drop(&mut self) {
105 unsafe { sqlite3_close(self.sqlite3) };
106 }
107}
108
#[cfg(test)]
mod test {
    use anyhow::Result;
    use indoc::indoc;

    use crate::connection::Connection;

    #[test]
    fn string_round_trips() -> Result<()> {
        let connection = Connection::open_memory(Some("string_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE text (
                    text TEXT
                );"})
            .unwrap()()
            .unwrap();

        let text = "Some test text";

        connection
            .exec_bound("INSERT INTO text (text) VALUES (?);")
            .unwrap()(text)
            .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM text;").unwrap()().unwrap(),
            Some(text.to_string())
        );

        Ok(())
    }

    #[test]
    fn tuple_round_trips() {
        let connection = Connection::open_memory(Some("tuple_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    text TEXT,
                    integer INTEGER,
                    blob BLOB
                );"})
            .unwrap()()
            .unwrap();

        let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
        let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);

        let mut insert = connection
            .exec_bound::<(String, usize, Vec<u8>)>(
                "INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)",
            )
            .unwrap();

        insert(tuple1.clone()).unwrap();
        insert(tuple2.clone()).unwrap();

        assert_eq!(
            connection
                .select::<(String, usize, Vec<u8>)>("SELECT * FROM test")
                .unwrap()()
                .unwrap(),
            vec![tuple1, tuple2]
        );
    }

    #[test]
    fn bool_round_trips() {
        let connection = Connection::open_memory(Some("bool_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE bools (
                    t INTEGER,
                    f INTEGER
                );"})
            .unwrap()()
            .unwrap();

        connection
            .exec_bound("INSERT INTO bools(t, f) VALUES (?, ?)")
            .unwrap()((true, false))
            .unwrap();

        assert_eq!(
            connection
                .select_row::<(bool, bool)>("SELECT * FROM bools;")
                .unwrap()()
                .unwrap(),
            Some((true, false))
        );
    }

    #[test]
    fn backup_works() {
        let connection1 = Connection::open_memory(Some("backup_works"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap()()
            .unwrap();
        let blob = vec![0, 1, 2, 4, 8, 16, 32, 64];
        connection1
            .exec_bound::<Vec<u8>>("INSERT INTO blobs (data) VALUES (?);")
            .unwrap()(blob.clone())
            .unwrap();

        // Backup connection1 to connection2
        let connection2 = Connection::open_memory(Some("backup_works_other"));
        connection1.backup_main(&connection2).unwrap();

        // Read from the backup destination (not the source) so the test actually
        // verifies that the table and its row were copied over.
        let read_blobs = connection2
            .select::<Vec<u8>>("SELECT * FROM blobs;")
            .unwrap()()
            .unwrap();
        assert_eq!(read_blobs, vec![blob]);
    }

    #[test]
    fn backup_to_self_fails() {
        let connection = Connection::open_memory(Some("backup_to_self_fails"));
        // SQLite refuses a backup whose source and destination are the same
        // connection. This must surface as an error instead of stepping a NULL
        // backup handle.
        assert!(connection.backup_main(&connection).is_err());
    }

    #[test]
    fn multi_step_statement_works() {
        let connection = Connection::open_memory(Some("multi_step_statement_works"));

        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    col INTEGER
                )"})
            .unwrap()()
            .unwrap();

        connection
            .exec(indoc! {"
                INSERT INTO test(col) VALUES (2)"})
            .unwrap()()
            .unwrap();

        assert_eq!(
            connection
                .select_row::<usize>("SELECt * FROM test")
                .unwrap()()
                .unwrap(),
            Some(2)
        );
    }
}