1use std::{
2 ffi::{CStr, CString},
3 marker::PhantomData,
4};
5
6use anyhow::{anyhow, Result};
7use libsqlite3_sys::*;
8
9use crate::statement::Statement;
10
11pub struct Connection {
12 pub(crate) sqlite3: *mut sqlite3,
13 persistent: bool,
14 phantom: PhantomData<sqlite3>,
15}
16unsafe impl Send for Connection {}
17
18impl Connection {
19 fn open(uri: &str, persistent: bool) -> Result<Self> {
20 let mut connection = Self {
21 sqlite3: 0 as *mut _,
22 persistent,
23 phantom: PhantomData,
24 };
25
26 let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE;
27 unsafe {
28 sqlite3_open_v2(
29 CString::new(uri)?.as_ptr(),
30 &mut connection.sqlite3,
31 flags,
32 0 as *const _,
33 );
34
35 connection.last_error()?;
36 }
37
38 Ok(connection)
39 }
40
41 /// Attempts to open the database at uri. If it fails, a shared memory db will be opened
42 /// instead.
43 pub fn open_file(uri: &str) -> Self {
44 Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(uri))
45 }
46
47 pub fn open_memory(uri: &str) -> Self {
48 let in_memory_path = format!("file:{}?mode=memory&cache=shared", uri);
49 Self::open(&in_memory_path, false).expect("Could not create fallback in memory db")
50 }
51
52 pub fn persistent(&self) -> bool {
53 self.persistent
54 }
55
56 pub(crate) fn last_insert_id(&self) -> i64 {
57 unsafe { sqlite3_last_insert_rowid(self.sqlite3) }
58 }
59
60 pub fn insert(&self, query: impl AsRef<str>) -> Result<i64> {
61 self.exec(query)?;
62 Ok(self.last_insert_id())
63 }
64
65 pub fn exec(&self, query: impl AsRef<str>) -> Result<()> {
66 unsafe {
67 sqlite3_exec(
68 self.sqlite3,
69 CString::new(query.as_ref())?.as_ptr(),
70 None,
71 0 as *mut _,
72 0 as *mut _,
73 );
74 self.last_error()?;
75 }
76 Ok(())
77 }
78
79 pub fn prepare<T: AsRef<str>>(&self, query: T) -> Result<Statement> {
80 Statement::prepare(&self, query)
81 }
82
83 pub fn backup_main(&self, destination: &Connection) -> Result<()> {
84 unsafe {
85 let backup = sqlite3_backup_init(
86 destination.sqlite3,
87 CString::new("main")?.as_ptr(),
88 self.sqlite3,
89 CString::new("main")?.as_ptr(),
90 );
91 sqlite3_backup_step(backup, -1);
92 sqlite3_backup_finish(backup);
93 destination.last_error()
94 }
95 }
96
97 pub(crate) fn last_error(&self) -> Result<()> {
98 const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW];
99 unsafe {
100 let code = sqlite3_errcode(self.sqlite3);
101 if NON_ERROR_CODES.contains(&code) {
102 return Ok(());
103 }
104
105 let message = sqlite3_errmsg(self.sqlite3);
106 let message = if message.is_null() {
107 None
108 } else {
109 Some(
110 String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes())
111 .into_owned(),
112 )
113 };
114
115 Err(anyhow!(
116 "Sqlite call failed with code {} and message: {:?}",
117 code as isize,
118 message
119 ))
120 }
121 }
122}
123
124impl Drop for Connection {
125 fn drop(&mut self) {
126 unsafe { sqlite3_close(self.sqlite3) };
127 }
128}
129
#[cfg(test)]
mod test {
    use anyhow::Result;
    use indoc::indoc;

    use crate::connection::Connection;

    /// A TEXT value written through a bound parameter reads back unchanged.
    #[test]
    fn string_round_trips() -> Result<()> {
        let connection = Connection::open_memory("string_round_trips");
        connection
            .exec(indoc! {"
            CREATE TABLE text (
                text TEXT
            );"})
            .unwrap();

        let text = "Some test text";

        connection
            .prepare("INSERT INTO text (text) VALUES (?);")
            .unwrap()
            .with_bindings(text)
            .unwrap()
            .exec()
            .unwrap();

        assert_eq!(
            &connection
                .prepare("SELECT text FROM text;")
                .unwrap()
                .row::<String>()
                .unwrap(),
            text
        );

        Ok(())
    }

    /// Tuples of (TEXT, INTEGER, BLOB) survive an insert/select round trip, and one
    /// prepared statement can be re-bound and executed twice.
    #[test]
    fn tuple_round_trips() {
        let connection = Connection::open_memory("tuple_round_trips");
        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    text TEXT,
                    integer INTEGER,
                    blob BLOB
                );"})
            .unwrap();

        let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]);
        let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]);

        let mut insert = connection
            .prepare("INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)")
            .unwrap();

        insert
            .with_bindings(tuple1.clone())
            .unwrap()
            .exec()
            .unwrap();
        insert
            .with_bindings(tuple2.clone())
            .unwrap()
            .exec()
            .unwrap();

        assert_eq!(
            connection
                .prepare("SELECT * FROM test")
                .unwrap()
                .rows::<(String, usize, Vec<u8>)>()
                .unwrap(),
            vec![tuple1, tuple2]
        );
    }

    /// A failing `exec` is reported as an error, and the connection stays usable.
    #[test]
    fn exec_reports_errors() {
        let connection = Connection::open_memory("exec_reports_errors");
        assert!(connection.exec("SELECT * FROM missing_table;").is_err());

        // A previous failure must not make later successful statements report errors.
        connection
            .exec("CREATE TABLE present (value INTEGER);")
            .unwrap();
    }

    /// Data written to one connection is readable from the backup destination.
    #[test]
    fn backup_works() {
        let connection1 = Connection::open_memory("backup_works");
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap();
        let blob = &[0, 1, 2, 4, 8, 16, 32, 64];
        let mut write = connection1
            .prepare("INSERT INTO blobs (data) VALUES (?);")
            .unwrap();
        write.bind_blob(1, blob).unwrap();
        write.exec().unwrap();

        // Backup connection1 to connection2
        let connection2 = Connection::open_memory("backup_works_other");
        connection1.backup_main(&connection2).unwrap();

        // Read from the destination (not the source) so the backup is actually verified.
        let read_blobs = connection2
            .prepare("SELECT * FROM blobs;")
            .unwrap()
            .rows::<Vec<u8>>()
            .unwrap();
        assert_eq!(read_blobs, vec![blob]);
    }
}