1use std::ffi::{c_int, CString};
2use std::marker::PhantomData;
3use std::{slice, str};
4
5use anyhow::{anyhow, Context, Result};
6use libsqlite3_sys::*;
7
8use crate::bindable::{Bind, Column};
9use crate::connection::{error_to_result, Connection};
10
11pub struct Statement<'a> {
12 raw_statement: *mut sqlite3_stmt,
13 connection: &'a Connection,
14 phantom: PhantomData<sqlite3_stmt>,
15}
16
/// Outcome of a single `sqlite3_step` call on a statement.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StepResult {
    /// A result row is available (`SQLITE_ROW`).
    Row,
    /// The statement has finished executing (`SQLITE_DONE`).
    Done,
    /// SQLite reported `SQLITE_MISUSE`.
    Misuse,
    /// Any other result code, carried through unchanged.
    Other(i32),
}
24
/// Storage class of a column value, as reported by `sqlite3_column_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlType {
    /// `SQLITE_TEXT`.
    Text,
    /// `SQLITE_INTEGER`.
    Integer,
    /// `SQLITE_BLOB`.
    Blob,
    /// `SQLITE_FLOAT`.
    Float,
    /// `SQLITE_NULL`.
    Null,
}
33
34impl<'a> Statement<'a> {
35 pub fn prepare<T: AsRef<str>>(connection: &'a Connection, query: T) -> Result<Self> {
36 let mut statement = Self {
37 raw_statement: 0 as *mut _,
38 connection,
39 phantom: PhantomData,
40 };
41
42 unsafe {
43 sqlite3_prepare_v2(
44 connection.sqlite3,
45 CString::new(query.as_ref())?.as_ptr(),
46 -1,
47 &mut statement.raw_statement,
48 0 as *mut _,
49 );
50
51 connection.last_error().context("Prepare call failed.")?;
52 }
53
54 Ok(statement)
55 }
56
57 pub fn reset(&mut self) {
58 unsafe {
59 sqlite3_reset(self.raw_statement);
60 }
61 }
62
63 pub fn parameter_count(&self) -> i32 {
64 unsafe { sqlite3_bind_parameter_count(self.raw_statement) }
65 }
66
67 pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
68 // dbg!("bind blob", index);
69 let index = index as c_int;
70 let blob_pointer = blob.as_ptr() as *const _;
71 let len = blob.len() as c_int;
72 unsafe {
73 sqlite3_bind_blob(
74 self.raw_statement,
75 index,
76 blob_pointer,
77 len,
78 SQLITE_TRANSIENT(),
79 );
80 }
81 self.connection.last_error()
82 }
83
84 pub fn column_blob<'b>(&'b mut self, index: i32) -> Result<&'b [u8]> {
85 let index = index as c_int;
86 let pointer = unsafe { sqlite3_column_blob(self.raw_statement, index) };
87
88 self.connection.last_error()?;
89 if pointer.is_null() {
90 return Ok(&[]);
91 }
92 let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize };
93 self.connection.last_error()?;
94 unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
95 }
96
97 pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
98 // dbg!("bind double", index);
99 let index = index as c_int;
100
101 unsafe {
102 sqlite3_bind_double(self.raw_statement, index, double);
103 }
104 self.connection.last_error()
105 }
106
107 pub fn column_double(&self, index: i32) -> Result<f64> {
108 let index = index as c_int;
109 let result = unsafe { sqlite3_column_double(self.raw_statement, index) };
110 self.connection.last_error()?;
111 Ok(result)
112 }
113
114 pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
115 // dbg!("bind int", index);
116 let index = index as c_int;
117
118 unsafe {
119 sqlite3_bind_int(self.raw_statement, index, int);
120 };
121 self.connection.last_error()
122 }
123
124 pub fn column_int(&self, index: i32) -> Result<i32> {
125 let index = index as c_int;
126 let result = unsafe { sqlite3_column_int(self.raw_statement, index) };
127 self.connection.last_error()?;
128 Ok(result)
129 }
130
131 pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
132 // dbg!("bind int64", index);
133 let index = index as c_int;
134 unsafe {
135 sqlite3_bind_int64(self.raw_statement, index, int);
136 }
137 self.connection.last_error()
138 }
139
140 pub fn column_int64(&self, index: i32) -> Result<i64> {
141 let index = index as c_int;
142 let result = unsafe { sqlite3_column_int64(self.raw_statement, index) };
143 self.connection.last_error()?;
144 Ok(result)
145 }
146
147 pub fn bind_null(&self, index: i32) -> Result<()> {
148 // dbg!("bind null", index);
149 let index = index as c_int;
150 unsafe {
151 sqlite3_bind_null(self.raw_statement, index);
152 }
153 self.connection.last_error()
154 }
155
156 pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
157 // dbg!("bind text", index, text);
158 let index = index as c_int;
159 let text_pointer = text.as_ptr() as *const _;
160 let len = text.len() as c_int;
161 unsafe {
162 sqlite3_bind_text(
163 self.raw_statement,
164 index,
165 text_pointer,
166 len,
167 SQLITE_TRANSIENT(),
168 );
169 }
170 self.connection.last_error()
171 }
172
173 pub fn column_text<'b>(&'b mut self, index: i32) -> Result<&'b str> {
174 let index = index as c_int;
175 let pointer = unsafe { sqlite3_column_text(self.raw_statement, index) };
176
177 self.connection.last_error()?;
178 if pointer.is_null() {
179 return Ok("");
180 }
181 let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize };
182 self.connection.last_error()?;
183
184 let slice = unsafe { slice::from_raw_parts(pointer as *const u8, len) };
185 Ok(str::from_utf8(slice)?)
186 }
187
188 pub fn bind<T: Bind>(&self, value: T, index: i32) -> Result<i32> {
189 debug_assert!(index > 0);
190 value.bind(self, index)
191 }
192
193 pub fn column<T: Column>(&mut self) -> Result<T> {
194 let (result, _) = T::column(self, 0)?;
195 Ok(result)
196 }
197
198 pub fn column_type(&mut self, index: i32) -> Result<SqlType> {
199 let result = unsafe { sqlite3_column_type(self.raw_statement, index) }; // SELECT <FRIEND> FROM TABLE
200 self.connection.last_error()?;
201 match result {
202 SQLITE_INTEGER => Ok(SqlType::Integer),
203 SQLITE_FLOAT => Ok(SqlType::Float),
204 SQLITE_TEXT => Ok(SqlType::Text),
205 SQLITE_BLOB => Ok(SqlType::Blob),
206 SQLITE_NULL => Ok(SqlType::Null),
207 _ => Err(anyhow!("Column type returned was incorrect ")),
208 }
209 }
210
211 pub fn with_bindings(&mut self, bindings: impl Bind) -> Result<&mut Self> {
212 self.bind(bindings, 1)?;
213 Ok(self)
214 }
215
216 fn step(&mut self) -> Result<StepResult> {
217 unsafe {
218 match sqlite3_step(self.raw_statement) {
219 SQLITE_ROW => Ok(StepResult::Row),
220 SQLITE_DONE => Ok(StepResult::Done),
221 SQLITE_MISUSE => Ok(StepResult::Misuse),
222 other => self
223 .connection
224 .last_error()
225 .map(|_| StepResult::Other(other)),
226 }
227 }
228 }
229
230 pub fn insert(&mut self) -> Result<i64> {
231 self.exec()?;
232 Ok(self.connection.last_insert_id())
233 }
234
235 pub fn exec(&mut self) -> Result<()> {
236 fn logic(this: &mut Statement) -> Result<()> {
237 while this.step()? == StepResult::Row {}
238 Ok(())
239 }
240 let result = logic(self);
241 self.reset();
242 result
243 }
244
245 pub fn map<R>(&mut self, callback: impl FnMut(&mut Statement) -> Result<R>) -> Result<Vec<R>> {
246 fn logic<R>(
247 this: &mut Statement,
248 mut callback: impl FnMut(&mut Statement) -> Result<R>,
249 ) -> Result<Vec<R>> {
250 let mut mapped_rows = Vec::new();
251 while this.step()? == StepResult::Row {
252 mapped_rows.push(callback(this)?);
253 }
254 Ok(mapped_rows)
255 }
256
257 let result = logic(self, callback);
258 self.reset();
259 result
260 }
261
262 pub fn rows<R: Column>(&mut self) -> Result<Vec<R>> {
263 self.map(|s| s.column::<R>())
264 }
265
266 pub fn single<R>(&mut self, callback: impl FnOnce(&mut Statement) -> Result<R>) -> Result<R> {
267 fn logic<R>(
268 this: &mut Statement,
269 callback: impl FnOnce(&mut Statement) -> Result<R>,
270 ) -> Result<R> {
271 if this.step()? != StepResult::Row {
272 return Err(anyhow!(
273 "Single(Map) called with query that returns no rows."
274 ));
275 }
276 callback(this)
277 }
278 let result = logic(self, callback);
279 self.reset();
280 result
281 }
282
283 pub fn row<R: Column>(&mut self) -> Result<R> {
284 self.single(|this| this.column::<R>())
285 }
286
287 pub fn maybe<R>(
288 &mut self,
289 callback: impl FnOnce(&mut Statement) -> Result<R>,
290 ) -> Result<Option<R>> {
291 fn logic<R>(
292 this: &mut Statement,
293 callback: impl FnOnce(&mut Statement) -> Result<R>,
294 ) -> Result<Option<R>> {
295 if this.step()? != StepResult::Row {
296 return Ok(None);
297 }
298 callback(this).map(|r| Some(r))
299 }
300 let result = logic(self, callback);
301 self.reset();
302 result
303 }
304
305 pub fn maybe_row<R: Column>(&mut self) -> Result<Option<R>> {
306 self.maybe(|this| this.column::<R>())
307 }
308}
309
310impl<'a> Drop for Statement<'a> {
311 fn drop(&mut self) {
312 unsafe {
313 let error = sqlite3_finalize(self.raw_statement);
314 error_to_result(error).expect("failed error");
315 };
316 }
317}
318
#[cfg(test)]
mod test {
    use indoc::indoc;

    use crate::{connection::Connection, statement::StepResult};

    /// A blob written through one connection reads back byte-for-byte through a second
    /// connection opened on the same named in-memory database, and a delete made through the
    /// second connection is visible to the first.
    #[test]
    fn blob_round_trips() {
        let connection1 = Connection::open_memory("blob_round_trips");
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap();

        let blob = &[0, 1, 2, 4, 8, 16, 32, 64];

        let mut write = connection1
            .prepare("INSERT INTO blobs (data) VALUES (?);")
            .unwrap();
        write.bind_blob(1, blob).unwrap();
        assert_eq!(write.step().unwrap(), StepResult::Done);

        // Read the blob back through a second connection to the same in-memory database.
        let connection2 = Connection::open_memory("blob_round_trips");
        let mut read = connection2.prepare("SELECT * FROM blobs;").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Row);
        assert_eq!(read.column_blob(0).unwrap(), blob);
        assert_eq!(read.step().unwrap(), StepResult::Done);

        // Delete the blob through the second connection and verify that it's gone when
        // queried through the first.
        connection2.exec("DELETE FROM blobs;").unwrap();
        let mut read = connection1.prepare("SELECT * FROM blobs;").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Done);
    }
}