1use std::{
2 ffi::{CStr, CString},
3 marker::PhantomData,
4};
5
6use anyhow::{anyhow, Result};
7use libsqlite3_sys::*;
8
9use crate::statement::Statement;
10
/// An owned handle to a SQLite database connection.
///
/// Instances are created by [`Connection::open_file`] / [`Connection::open_memory`] and the
/// underlying SQLite handle is closed when the value is dropped.
pub struct Connection {
    /// Raw SQLite connection handle, written by `sqlite3_open_v2` when the connection is opened.
    pub(crate) sqlite3: *mut sqlite3,
    /// `true` when the database at the requested file URI was opened; `false` for in-memory
    /// databases (including the in-memory fallback used by `open_file`).
    persistent: bool,
    /// Records that `Connection` logically owns the `sqlite3` object behind the raw pointer.
    phantom: PhantomData<sqlite3>,
}
// SAFETY: connections are opened with SQLITE_OPEN_NOMUTEX. Assuming SQLite was compiled with
// thread support, that mode lets a connection be used from any thread as long as no two
// threads use it at the same time. `Connection` is `Send` but not `Sync` (the raw-pointer
// field prevents an automatic `Sync` impl), so ownership may move between threads while
// shared `&Connection` access cannot cross threads.
unsafe impl Send for Connection {}
17
18impl Connection {
19 fn open(uri: &str, persistent: bool) -> Result<Self> {
20 let mut connection = Self {
21 sqlite3: 0 as *mut _,
22 persistent,
23 phantom: PhantomData,
24 };
25
26 let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
27 unsafe {
28 sqlite3_open_v2(
29 CString::new(uri)?.as_ptr(),
30 &mut connection.sqlite3,
31 flags,
32 0 as *const _,
33 );
34
35 // Turn on extended error codes
36 sqlite3_extended_result_codes(connection.sqlite3, 1);
37
38 connection.last_error()?;
39 }
40
41 Ok(connection)
42 }
43
44 /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
45 /// instead.
46 pub fn open_file(uri: &str) -> Self {
47 Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(uri))
48 }
49
50 pub fn open_memory(uri: &str) -> Self {
51 let in_memory_path = format!("file:{}?mode=memory&cache=shared", uri);
52 Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
53 }
54
55 pub fn persistent(&self) -> bool {
56 self.persistent
57 }
58
59 pub(crate) fn last_insert_id(&self) -> i64 {
60 unsafe { sqlite3_last_insert_rowid(self.sqlite3) }
61 }
62
63 pub fn insert(&self, query: impl AsRef<str>) -> Result<i64> {
64 self.exec(query)?;
65 Ok(self.last_insert_id())
66 }
67
68 pub fn exec(&self, query: impl AsRef<str>) -> Result<()> {
69 unsafe {
70 sqlite3_exec(
71 self.sqlite3,
72 CString::new(query.as_ref())?.as_ptr(),
73 None,
74 0 as *mut _,
75 0 as *mut _,
76 );
77 sqlite3_errcode(self.sqlite3);
78 self.last_error()?;
79 }
80 Ok(())
81 }
82
83 pub fn prepare<T: AsRef<str>>(&self, query: T) -> Result<Statement> {
84 Statement::prepare(&self, query)
85 }
86
87 pub fn backup_main(&self, destination: &Connection) -> Result<()> {
88 unsafe {
89 let backup = sqlite3_backup_init(
90 destination.sqlite3,
91 CString::new("main")?.as_ptr(),
92 self.sqlite3,
93 CString::new("main")?.as_ptr(),
94 );
95 sqlite3_backup_step(backup, -1);
96 sqlite3_backup_finish(backup);
97 destination.last_error()
98 }
99 }
100
101 pub(crate) fn last_error(&self) -> Result<()> {
102 unsafe { error_to_result(sqlite3_errcode(self.sqlite3)) }
103 }
104}
105
106impl Drop for Connection {
107 fn drop(&mut self) {
108 unsafe { sqlite3_close(self.sqlite3) };
109 }
110}
111
112pub(crate) fn error_to_result(code: std::os::raw::c_int) -> Result<()> {
113 const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
114 unsafe {
115 if NON_ERROR_CODES.contains(&code) {
116 return Ok(());
117 }
118
119 let message = sqlite3_errstr(code);
120 let message = if message.is_null() {
121 None
122 } else {
123 Some(
124 String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
125 .into_owned(),
126 )
127 };
128
129 Err(anyhow!(
130 "Sqlite call failed with code {} and message: {:?}",
131 code as isize,
132 message
133 ))
134 }
135}
136
#[cfg(test)]
mod test {
    use anyhow::Result;
    use indoc::indoc;

    use crate::{connection::Connection, migrations::Migration};

    #[test]
    fn string_round_trips() -> Result<()> {
        let connection = Connection::open_memory("string_round_trips");
        connection
            .exec(indoc! {"
            CREATE TABLE text (
                text TEXT
            );"})
            .unwrap();

        let text = "Some test text";

        connection
            .prepare("INSERT INTO text (text) VALUES (?);")
            .unwrap()
            .with_bindings(text)
            .unwrap()
            .exec()
            .unwrap();

        assert_eq!(
            &connection
                .prepare("SELECT text FROM text;")
                .unwrap()
                .row::<String>()
                .unwrap(),
            text
        );

        Ok(())
    }

    #[test]
    fn tuple_round_trips() {
        let connection = Connection::open_memory("tuple_round_trips");
        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    text TEXT,
                    integer INTEGER,
                    blob BLOB
                );"})
            .unwrap();

        let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
        let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);

        let mut insert = connection
            .prepare("INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)")
            .unwrap();

        insert
            .with_bindings(tuple1.clone())
            .unwrap()
            .exec()
            .unwrap();
        insert
            .with_bindings(tuple2.clone())
            .unwrap()
            .exec()
            .unwrap();

        assert_eq!(
            connection
                .prepare("SELECT * FROM test")
                .unwrap()
                .rows::<(String, usize, Vec<u8>)>()
                .unwrap(),
            vec![tuple1, tuple2]
        );
    }

    #[test]
    fn backup_works() {
        let connection1 = Connection::open_memory("backup_works");
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap();
        let blob = &[0, 1, 2, 4, 8, 16, 32, 64];
        let mut write = connection1
            .prepare("INSERT INTO blobs (data) VALUES (?);")
            .unwrap();
        write.bind_blob(1, blob).unwrap();
        write.exec().unwrap();

        // Backup connection1 to connection2
        let connection2 = Connection::open_memory("backup_works_other");
        connection1.backup_main(&connection2).unwrap();

        // Verify the blob was copied into the destination database. Reading from
        // connection2 (not connection1) is what actually exercises the backup.
        let read_blobs = connection2
            .prepare("SELECT * FROM blobs;")
            .unwrap()
            .rows::<Vec<u8>>()
            .unwrap();
        assert_eq!(read_blobs, vec![blob]);
    }

    #[test]
    fn test_kv_store() -> anyhow::Result<()> {
        let connection = Connection::open_memory("kv_store");

        Migration::new(
            "kv",
            &["CREATE TABLE kv_store(
                key TEXT PRIMARY KEY,
                value TEXT NOT NULL
            ) STRICT;"],
        )
        .run(&connection)
        .unwrap();

        let mut stmt = connection.prepare("INSERT INTO kv_store(key, value) VALUES(?, ?)")?;
        stmt.bind_text(1, "a").unwrap();
        stmt.bind_text(2, "b").unwrap();
        stmt.exec().unwrap();
        let id = connection.last_insert_id();

        let res = connection
            .prepare("SELECT key, value FROM kv_store WHERE rowid = ?")?
            .with_bindings(id)?
            .row::<(String, String)>()?;

        assert_eq!(res, ("a".to_string(), "b".to_string()));

        Ok(())
    }
}