1use std::ffi::{c_int, CString};
2use std::marker::PhantomData;
3use std::{slice, str};
4
5use anyhow::{anyhow, Context, Result};
6use libsqlite3_sys::*;
7
8use crate::bindable::{Bind, Column};
9use crate::connection::Connection;
10
11pub struct Statement<'a> {
12 raw_statement: *mut sqlite3_stmt,
13 connection: &'a Connection,
14 phantom: PhantomData<sqlite3_stmt>,
15}
16
/// Outcome of one `sqlite3_step` call.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StepResult {
    /// `SQLITE_ROW`: a result row is ready to be read.
    Row,
    /// `SQLITE_DONE`: the statement has finished executing.
    Done,
    /// `SQLITE_MISUSE`: the statement was used incorrectly.
    Misuse,
    /// Any other result code, carried through unchanged.
    Other(i32),
}
24
/// Storage class of a result column value, as reported by
/// `sqlite3_column_type`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum SqlType {
    /// `SQLITE_TEXT`.
    Text,
    /// `SQLITE_INTEGER`.
    Integer,
    /// `SQLITE_BLOB`.
    Blob,
    /// `SQLITE_FLOAT`.
    Float,
    /// `SQLITE_NULL`.
    Null,
}
33
34impl<'a> Statement<'a> {
35 pub fn prepare<T: AsRef<str>>(connection: &'a Connection, query: T) -> Result<Self> {
36 let mut statement = Self {
37 raw_statement: 0 as *mut _,
38 connection,
39 phantom: PhantomData,
40 };
41
42 unsafe {
43 sqlite3_prepare_v2(
44 connection.sqlite3,
45 CString::new(query.as_ref())?.as_ptr(),
46 -1,
47 &mut statement.raw_statement,
48 0 as *mut _,
49 );
50
51 connection.last_error().context("Prepare call failed.")?;
52 }
53
54 Ok(statement)
55 }
56
57 pub fn reset(&mut self) {
58 unsafe {
59 sqlite3_reset(self.raw_statement);
60 }
61 }
62
63 pub fn parameter_count(&self) -> i32 {
64 unsafe { sqlite3_bind_parameter_count(self.raw_statement) }
65 }
66
67 pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
68 let index = index as c_int;
69 let blob_pointer = blob.as_ptr() as *const _;
70 let len = blob.len() as c_int;
71 unsafe {
72 sqlite3_bind_blob(
73 self.raw_statement,
74 index,
75 blob_pointer,
76 len,
77 SQLITE_TRANSIENT(),
78 );
79 }
80 self.connection.last_error()
81 }
82
83 pub fn column_blob<'b>(&'b mut self, index: i32) -> Result<&'b [u8]> {
84 let index = index as c_int;
85 let pointer = unsafe { sqlite3_column_blob(self.raw_statement, index) };
86
87 self.connection.last_error()?;
88 if pointer.is_null() {
89 return Ok(&[]);
90 }
91 let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize };
92 self.connection.last_error()?;
93 unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
94 }
95
96 pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
97 let index = index as c_int;
98
99 unsafe {
100 sqlite3_bind_double(self.raw_statement, index, double);
101 }
102 self.connection.last_error()
103 }
104
105 pub fn column_double(&self, index: i32) -> Result<f64> {
106 let index = index as c_int;
107 let result = unsafe { sqlite3_column_double(self.raw_statement, index) };
108 self.connection.last_error()?;
109 Ok(result)
110 }
111
112 pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
113 let index = index as c_int;
114
115 unsafe {
116 sqlite3_bind_int(self.raw_statement, index, int);
117 }
118 self.connection.last_error()
119 }
120
121 pub fn column_int(&self, index: i32) -> Result<i32> {
122 let index = index as c_int;
123 let result = unsafe { sqlite3_column_int(self.raw_statement, index) };
124 self.connection.last_error()?;
125 Ok(result)
126 }
127
128 pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
129 let index = index as c_int;
130 unsafe {
131 sqlite3_bind_int64(self.raw_statement, index, int);
132 }
133 self.connection.last_error()
134 }
135
136 pub fn column_int64(&self, index: i32) -> Result<i64> {
137 let index = index as c_int;
138 let result = unsafe { sqlite3_column_int64(self.raw_statement, index) };
139 self.connection.last_error()?;
140 Ok(result)
141 }
142
143 pub fn bind_null(&self, index: i32) -> Result<()> {
144 let index = index as c_int;
145 unsafe {
146 sqlite3_bind_null(self.raw_statement, index);
147 }
148 self.connection.last_error()
149 }
150
151 pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
152 let index = index as c_int;
153 let text_pointer = text.as_ptr() as *const _;
154 let len = text.len() as c_int;
155 unsafe {
156 sqlite3_bind_blob(
157 self.raw_statement,
158 index,
159 text_pointer,
160 len,
161 SQLITE_TRANSIENT(),
162 );
163 }
164 self.connection.last_error()
165 }
166
167 pub fn column_text<'b>(&'b mut self, index: i32) -> Result<&'b str> {
168 let index = index as c_int;
169 let pointer = unsafe { sqlite3_column_text(self.raw_statement, index) };
170
171 self.connection.last_error()?;
172 if pointer.is_null() {
173 return Ok("");
174 }
175 let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize };
176 self.connection.last_error()?;
177
178 let slice = unsafe { slice::from_raw_parts(pointer as *const u8, len) };
179 Ok(str::from_utf8(slice)?)
180 }
181
182 pub fn bind_value<T: Bind>(&self, value: T, idx: i32) -> Result<()> {
183 debug_assert!(idx > 0);
184 value.bind(self, idx)?;
185 Ok(())
186 }
187
188 pub fn column<T: Column>(&mut self) -> Result<T> {
189 let (result, _) = T::column(self, 0)?;
190 Ok(result)
191 }
192
193 pub fn column_type(&mut self, index: i32) -> Result<SqlType> {
194 let result = unsafe { sqlite3_column_type(self.raw_statement, index) }; // SELECT <FRIEND> FROM TABLE
195 self.connection.last_error()?;
196 match result {
197 SQLITE_INTEGER => Ok(SqlType::Integer),
198 SQLITE_FLOAT => Ok(SqlType::Float),
199 SQLITE_TEXT => Ok(SqlType::Text),
200 SQLITE_BLOB => Ok(SqlType::Blob),
201 SQLITE_NULL => Ok(SqlType::Null),
202 _ => Err(anyhow!("Column type returned was incorrect ")),
203 }
204 }
205
206 pub fn bind(&mut self, bindings: impl Bind) -> Result<&mut Self> {
207 self.bind_value(bindings, 1)?;
208 Ok(self)
209 }
210
211 fn step(&mut self) -> Result<StepResult> {
212 unsafe {
213 match sqlite3_step(self.raw_statement) {
214 SQLITE_ROW => Ok(StepResult::Row),
215 SQLITE_DONE => Ok(StepResult::Done),
216 SQLITE_MISUSE => Ok(StepResult::Misuse),
217 other => self
218 .connection
219 .last_error()
220 .map(|_| StepResult::Other(other)),
221 }
222 }
223 }
224
225 pub fn insert(&mut self) -> Result<i64> {
226 self.exec()?;
227 Ok(self.connection.last_insert_id())
228 }
229
230 pub fn exec(&mut self) -> Result<()> {
231 fn logic(this: &mut Statement) -> Result<()> {
232 while this.step()? == StepResult::Row {}
233 Ok(())
234 }
235 let result = logic(self);
236 self.reset();
237 result
238 }
239
240 pub fn map<R>(&mut self, callback: impl FnMut(&mut Statement) -> Result<R>) -> Result<Vec<R>> {
241 fn logic<R>(
242 this: &mut Statement,
243 mut callback: impl FnMut(&mut Statement) -> Result<R>,
244 ) -> Result<Vec<R>> {
245 let mut mapped_rows = Vec::new();
246 while this.step()? == StepResult::Row {
247 mapped_rows.push(callback(this)?);
248 }
249 Ok(mapped_rows)
250 }
251
252 let result = logic(self, callback);
253 self.reset();
254 result
255 }
256
257 pub fn rows<R: Column>(&mut self) -> Result<Vec<R>> {
258 self.map(|s| s.column::<R>())
259 }
260
261 pub fn single<R>(&mut self, callback: impl FnOnce(&mut Statement) -> Result<R>) -> Result<R> {
262 fn logic<R>(
263 this: &mut Statement,
264 callback: impl FnOnce(&mut Statement) -> Result<R>,
265 ) -> Result<R> {
266 if this.step()? != StepResult::Row {
267 return Err(anyhow!(
268 "Single(Map) called with query that returns no rows."
269 ));
270 }
271 callback(this)
272 }
273 let result = logic(self, callback);
274 self.reset();
275 result
276 }
277
278 pub fn row<R: Column>(&mut self) -> Result<R> {
279 self.single(|this| this.column::<R>())
280 }
281
282 pub fn maybe<R>(
283 &mut self,
284 callback: impl FnOnce(&mut Statement) -> Result<R>,
285 ) -> Result<Option<R>> {
286 fn logic<R>(
287 this: &mut Statement,
288 callback: impl FnOnce(&mut Statement) -> Result<R>,
289 ) -> Result<Option<R>> {
290 if this.step()? != StepResult::Row {
291 return Ok(None);
292 }
293 callback(this).map(|r| Some(r))
294 }
295 let result = logic(self, callback);
296 self.reset();
297 result
298 }
299
300 pub fn maybe_row<R: Column>(&mut self) -> Result<Option<R>> {
301 self.maybe(|this| this.column::<R>())
302 }
303}
304
305impl<'a> Drop for Statement<'a> {
306 fn drop(&mut self) {
307 unsafe {
308 sqlite3_finalize(self.raw_statement);
309 self.connection
310 .last_error()
311 .expect("sqlite3 finalize failed for statement :(");
312 };
313 }
314}
315
#[cfg(test)]
mod test {
    use indoc::indoc;

    use crate::{connection::Connection, statement::StepResult};

    // Writes a blob through one connection and reads it back through a second
    // connection opened with the same `open_memory` name. The test relies on
    // both connections seeing the same database.
    #[test]
    fn blob_round_trips() {
        let connection1 = Connection::open_memory("blob_round_trips");
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                data BLOB
                );"})
            .unwrap();

        let blob = &[0, 1, 2, 4, 8, 16, 32, 64];

        let mut write = connection1
            .prepare("INSERT INTO blobs (data) VALUES (?);")
            .unwrap();
        write.bind_blob(1, blob).unwrap();
        assert_eq!(write.step().unwrap(), StepResult::Done);

        // Read the blob back through a second connection to the same named
        // in-memory database.
        let connection2 = Connection::open_memory("blob_round_trips");
        let mut read = connection2.prepare("SELECT * FROM blobs;").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Row);
        assert_eq!(read.column_blob(0).unwrap(), blob);
        assert_eq!(read.step().unwrap(), StepResult::Done);

        // Delete the blob through connection2 and verify it is gone when
        // queried from connection1.
        connection2.exec("DELETE FROM blobs;").unwrap();
        let mut read = connection1.prepare("SELECT * FROM blobs;").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Done);
    }
}