1use std::{
2 ffi::{CStr, CString},
3 marker::PhantomData,
4 path::Path,
5 ptr,
6};
7
8use anyhow::{anyhow, Result};
9use libsqlite3_sys::*;
10
11pub struct Connection {
12 pub(crate) sqlite3: *mut sqlite3,
13 persistent: bool,
14 pub(crate) write: bool,
15 _sqlite: PhantomData<sqlite3>,
16}
17unsafe impl Send for Connection {}
18
19impl Connection {
20 pub(crate) fn open(uri: &str, persistent: bool) -> Result<Self> {
21 let mut connection = Self {
22 sqlite3: 0 as *mut _,
23 persistent,
24 write: true,
25 _sqlite: PhantomData,
26 };
27
28 let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
29 unsafe {
30 sqlite3_open_v2(
31 CString::new(uri)?.as_ptr(),
32 &mut connection.sqlite3,
33 flags,
34 0 as *const _,
35 );
36
37 // Turn on extended error codes
38 sqlite3_extended_result_codes(connection.sqlite3, 1);
39
40 connection.last_error()?;
41 }
42
43 Ok(connection)
44 }
45
46 /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
47 /// instead.
48 pub fn open_file(uri: &str) -> Self {
49 Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(Some(uri)))
50 }
51
52 pub fn open_memory(uri: Option<&str>) -> Self {
53 let in_memory_path = if let Some(uri) = uri {
54 format!("file:{}?mode=memory&cache=shared", uri)
55 } else {
56 ":memory:".to_string()
57 };
58
59 Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
60 }
61
62 pub fn persistent(&self) -> bool {
63 self.persistent
64 }
65
66 pub fn can_write(&self) -> bool {
67 self.write
68 }
69
70 pub fn backup_main(&self, destination: &Connection) -> Result<()> {
71 unsafe {
72 let backup = sqlite3_backup_init(
73 destination.sqlite3,
74 CString::new("main")?.as_ptr(),
75 self.sqlite3,
76 CString::new("main")?.as_ptr(),
77 );
78 sqlite3_backup_step(backup, -1);
79 sqlite3_backup_finish(backup);
80 destination.last_error()
81 }
82 }
83
84 pub fn backup_main_to(&self, destination: impl AsRef<Path>) -> Result<()> {
85 let destination = Self::open_file(destination.as_ref().to_string_lossy().as_ref());
86 self.backup_main(&destination)
87 }
88
89 pub fn sql_has_syntax_error(&self, sql: &str) -> Option<(String, usize)> {
90 let sql = CString::new(sql).unwrap();
91 let mut remaining_sql = sql.as_c_str();
92 let sql_start = remaining_sql.as_ptr();
93
94 unsafe {
95 while {
96 let remaining_sql_str = remaining_sql.to_str().unwrap().trim();
97 remaining_sql_str != ";" && !remaining_sql_str.is_empty()
98 } {
99 let mut raw_statement = 0 as *mut sqlite3_stmt;
100 let mut remaining_sql_ptr = ptr::null();
101 sqlite3_prepare_v2(
102 self.sqlite3,
103 remaining_sql.as_ptr(),
104 -1,
105 &mut raw_statement,
106 &mut remaining_sql_ptr,
107 );
108
109 let res = sqlite3_errcode(self.sqlite3);
110 let offset = sqlite3_error_offset(self.sqlite3);
111 let message = sqlite3_errmsg(self.sqlite3);
112
113 sqlite3_finalize(raw_statement);
114
115 if res == 1 && offset >= 0 {
116 let err_msg =
117 String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
118 .into_owned();
119 let sub_statement_correction =
120 remaining_sql.as_ptr() as usize - sql_start as usize;
121
122 return Some((err_msg, offset as usize + sub_statement_correction));
123 }
124 remaining_sql = CStr::from_ptr(remaining_sql_ptr);
125 }
126 }
127 None
128 }
129
130 pub(crate) fn last_error(&self) -> Result<()> {
131 unsafe {
132 let code = sqlite3_errcode(self.sqlite3);
133 const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
134 if NON_ERROR_CODES.contains(&code) {
135 return Ok(());
136 }
137
138 let message = sqlite3_errmsg(self.sqlite3);
139 let message = if message.is_null() {
140 None
141 } else {
142 Some(
143 String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
144 .into_owned(),
145 )
146 };
147
148 Err(anyhow!(
149 "Sqlite call failed with code {} and message: {:?}",
150 code as isize,
151 message
152 ))
153 }
154 }
155}
156
157impl Drop for Connection {
158 fn drop(&mut self) {
159 unsafe { sqlite3_close(self.sqlite3) };
160 }
161}
162
#[cfg(test)]
mod test {
    // Round-trip and behaviour tests for `Connection`. Each test uses its own named in-memory
    // database so tests do not share state.
    use anyhow::Result;
    use indoc::indoc;

    use crate::connection::Connection;

    #[test]
    fn string_round_trips() -> Result<()> {
        let connection = Connection::open_memory(Some("string_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE text (
                    text TEXT
                );"})
            .unwrap()()
            .unwrap();

        let text = "Some test text";

        connection
            .exec_bound("INSERT INTO text (text) VALUES (?);")
            .unwrap()(text)
            .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM text;").unwrap()().unwrap(),
            Some(text.to_string())
        );

        Ok(())
    }

    #[test]
    fn tuple_round_trips() {
        let connection = Connection::open_memory(Some("tuple_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    text TEXT,
                    integer INTEGER,
                    blob BLOB
                );"})
            .unwrap()()
            .unwrap();

        let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
        let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);

        let mut insert = connection
            .exec_bound::<(String, usize, Vec<u8>)>(
                "INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)",
            )
            .unwrap();

        insert(tuple1.clone()).unwrap();
        insert(tuple2.clone()).unwrap();

        assert_eq!(
            connection
                .select::<(String, usize, Vec<u8>)>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            vec![tuple1, tuple2]
        );
    }

    #[test]
    fn bool_round_trips() {
        let connection = Connection::open_memory(Some("bool_round_trips"));
        connection
            .exec(indoc! {"
                CREATE TABLE bools (
                    t INTEGER,
                    f INTEGER
                );"})
            .unwrap()()
            .unwrap();

        connection
            .exec_bound("INSERT INTO bools(t, f) VALUES (?, ?)")
            .unwrap()((true, false))
            .unwrap();

        assert_eq!(
            connection
                .select_row::<(bool, bool)>("SELECT * FROM bools;")
                .unwrap()()
            .unwrap(),
            Some((true, false))
        );
    }

    #[test]
    fn backup_works() {
        let connection1 = Connection::open_memory(Some("backup_works"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap()()
            .unwrap();
        let blob = vec![0, 1, 2, 4, 8, 16, 32, 64];
        connection1
            .exec_bound::<Vec<u8>>("INSERT INTO blobs (data) VALUES (?);")
            .unwrap()(blob.clone())
            .unwrap();

        // Backup connection1 to connection2
        let connection2 = Connection::open_memory(Some("backup_works_other"));
        connection1.backup_main(&connection2).unwrap();

        // Read through the *destination*: the table and blob must have been copied over.
        let read_blobs = connection2
            .select::<Vec<u8>>("SELECT * FROM blobs;")
            .unwrap()()
        .unwrap();
        assert_eq!(read_blobs, vec![blob]);
    }

    #[test]
    fn multi_step_statement_works() {
        let connection = Connection::open_memory(Some("multi_step_statement_works"));

        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    col INTEGER
                )"})
            .unwrap()()
            .unwrap();

        connection
            .exec(indoc! {"
                INSERT INTO test(col) VALUES (2)"})
            .unwrap()()
            .unwrap();

        assert_eq!(
            connection
                .select_row::<usize>("SELECT * FROM test")
                .unwrap()()
            .unwrap(),
            Some(2)
        );
    }

    #[test]
    fn test_sql_has_syntax_errors() {
        let connection = Connection::open_memory(Some("test_sql_has_syntax_errors"));
        let first_stmt =
            "CREATE TABLE kv_store(key TEXT PRIMARY KEY, value TEXT NOT NULL) STRICT ;";
        let second_stmt = "SELECT FROM";

        let second_offset = connection.sql_has_syntax_error(second_stmt).unwrap().1;

        let res = connection
            .sql_has_syntax_error(&format!("{}\n{}", first_stmt, second_stmt))
            .map(|(_, offset)| offset);

        // The error offset in the combined input is shifted by the first statement plus "\n".
        assert_eq!(res, Some(first_stmt.len() + second_offset + 1));
    }
}