1use std::ffi::{c_int, CStr, CString};
2use std::marker::PhantomData;
3use std::{ptr, slice, str};
4
5use anyhow::{anyhow, Context, Result};
6use libsqlite3_sys::*;
7
8use crate::bindable::{Bind, Column};
9use crate::connection::Connection;
10
11pub struct Statement<'a> {
12 raw_statements: Vec<*mut sqlite3_stmt>,
13 current_statement: usize,
14 connection: &'a Connection,
15 phantom: PhantomData<sqlite3_stmt>,
16}
17
/// What happened when a [`Statement`] was advanced by one step.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StepResult {
    /// A result row is ready to be read through the `column_*` accessors.
    Row,
    /// Every statement in the query has run to completion.
    Done,
    /// SQLite answered the step with `SQLITE_MISUSE`.
    Misuse,
    /// Any other SQLite result code, passed through unchanged.
    Other(i32),
}
25
/// Storage class SQLite reports for a column value in the current row.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum SqlType {
    /// `SQLITE_TEXT`
    Text,
    /// `SQLITE_INTEGER`
    Integer,
    /// `SQLITE_BLOB`
    Blob,
    /// `SQLITE_FLOAT`
    Float,
    /// `SQLITE_NULL`
    Null,
}
34
35impl<'a> Statement<'a> {
36 pub fn prepare<T: AsRef<str>>(connection: &'a Connection, query: T) -> Result<Self> {
37 let mut statement = Self {
38 raw_statements: Default::default(),
39 current_statement: 0,
40 connection,
41 phantom: PhantomData,
42 };
43
44 unsafe {
45 let sql = CString::new(query.as_ref())?;
46 let mut remaining_sql = sql.as_c_str();
47 while {
48 let remaining_sql_str = remaining_sql.to_str()?.trim();
49 remaining_sql_str != ";" && !remaining_sql_str.is_empty()
50 } {
51 let mut raw_statement = 0 as *mut sqlite3_stmt;
52 let mut remaining_sql_ptr = ptr::null();
53 sqlite3_prepare_v2(
54 connection.sqlite3,
55 remaining_sql.as_ptr(),
56 -1,
57 &mut raw_statement,
58 &mut remaining_sql_ptr,
59 );
60 remaining_sql = CStr::from_ptr(remaining_sql_ptr);
61 statement.raw_statements.push(raw_statement);
62 }
63
64 connection
65 .last_error()
66 .with_context(|| format!("Prepare call failed for query:\n{}", query.as_ref()))?;
67 }
68
69 Ok(statement)
70 }
71
72 fn current_statement(&self) -> *mut sqlite3_stmt {
73 *self.raw_statements.get(self.current_statement).unwrap()
74 }
75
76 pub fn reset(&mut self) {
77 unsafe {
78 for raw_statement in self.raw_statements.iter() {
79 sqlite3_reset(*raw_statement);
80 }
81 }
82 self.current_statement = 0;
83 }
84
85 pub fn parameter_count(&self) -> i32 {
86 unsafe {
87 self.raw_statements
88 .iter()
89 .map(|raw_statement| sqlite3_bind_parameter_count(*raw_statement))
90 .max()
91 .unwrap_or(0)
92 }
93 }
94
95 pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
96 let index = index as c_int;
97 let blob_pointer = blob.as_ptr() as *const _;
98 let len = blob.len() as c_int;
99 unsafe {
100 for raw_statement in self.raw_statements.iter() {
101 sqlite3_bind_blob(*raw_statement, index, blob_pointer, len, SQLITE_TRANSIENT());
102 }
103 }
104 self.connection.last_error()
105 }
106
107 pub fn column_blob<'b>(&'b mut self, index: i32) -> Result<&'b [u8]> {
108 let index = index as c_int;
109 let pointer = unsafe { sqlite3_column_blob(self.current_statement(), index) };
110
111 self.connection.last_error()?;
112 if pointer.is_null() {
113 return Ok(&[]);
114 }
115 let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
116 self.connection.last_error()?;
117 unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
118 }
119
120 pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
121 let index = index as c_int;
122
123 unsafe {
124 for raw_statement in self.raw_statements.iter() {
125 sqlite3_bind_double(*raw_statement, index, double);
126 }
127 }
128 self.connection.last_error()
129 }
130
131 pub fn column_double(&self, index: i32) -> Result<f64> {
132 let index = index as c_int;
133 let result = unsafe { sqlite3_column_double(self.current_statement(), index) };
134 self.connection.last_error()?;
135 Ok(result)
136 }
137
138 pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
139 let index = index as c_int;
140
141 unsafe {
142 for raw_statement in self.raw_statements.iter() {
143 sqlite3_bind_int(*raw_statement, index, int);
144 }
145 };
146 self.connection.last_error()
147 }
148
149 pub fn column_int(&self, index: i32) -> Result<i32> {
150 let index = index as c_int;
151 let result = unsafe { sqlite3_column_int(self.current_statement(), index) };
152 self.connection.last_error()?;
153 Ok(result)
154 }
155
156 pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
157 let index = index as c_int;
158 unsafe {
159 for raw_statement in self.raw_statements.iter() {
160 sqlite3_bind_int64(*raw_statement, index, int);
161 }
162 }
163 self.connection.last_error()
164 }
165
166 pub fn column_int64(&self, index: i32) -> Result<i64> {
167 let index = index as c_int;
168 let result = unsafe { sqlite3_column_int64(self.current_statement(), index) };
169 self.connection.last_error()?;
170 Ok(result)
171 }
172
173 pub fn bind_null(&self, index: i32) -> Result<()> {
174 let index = index as c_int;
175 unsafe {
176 for raw_statement in self.raw_statements.iter() {
177 sqlite3_bind_null(*raw_statement, index);
178 }
179 }
180 self.connection.last_error()
181 }
182
183 pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
184 let index = index as c_int;
185 let text_pointer = text.as_ptr() as *const _;
186 let len = text.len() as c_int;
187 unsafe {
188 for raw_statement in self.raw_statements.iter() {
189 sqlite3_bind_text(*raw_statement, index, text_pointer, len, SQLITE_TRANSIENT());
190 }
191 }
192 self.connection.last_error()
193 }
194
195 pub fn column_text<'b>(&'b mut self, index: i32) -> Result<&'b str> {
196 let index = index as c_int;
197 let pointer = unsafe { sqlite3_column_text(self.current_statement(), index) };
198
199 self.connection.last_error()?;
200 if pointer.is_null() {
201 return Ok("");
202 }
203 let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
204 self.connection.last_error()?;
205
206 let slice = unsafe { slice::from_raw_parts(pointer as *const u8, len) };
207 Ok(str::from_utf8(slice)?)
208 }
209
210 pub fn bind<T: Bind>(&self, value: T, index: i32) -> Result<i32> {
211 debug_assert!(index > 0);
212 value.bind(self, index)
213 }
214
215 pub fn column<T: Column>(&mut self) -> Result<T> {
216 let (result, _) = T::column(self, 0)?;
217 Ok(result)
218 }
219
220 pub fn column_type(&mut self, index: i32) -> Result<SqlType> {
221 let result = unsafe { sqlite3_column_type(self.current_statement(), index) };
222 self.connection.last_error()?;
223 match result {
224 SQLITE_INTEGER => Ok(SqlType::Integer),
225 SQLITE_FLOAT => Ok(SqlType::Float),
226 SQLITE_TEXT => Ok(SqlType::Text),
227 SQLITE_BLOB => Ok(SqlType::Blob),
228 SQLITE_NULL => Ok(SqlType::Null),
229 _ => Err(anyhow!("Column type returned was incorrect ")),
230 }
231 }
232
233 pub fn with_bindings(&mut self, bindings: impl Bind) -> Result<&mut Self> {
234 self.bind(bindings, 1)?;
235 Ok(self)
236 }
237
238 fn step(&mut self) -> Result<StepResult> {
239 unsafe {
240 match sqlite3_step(self.current_statement()) {
241 SQLITE_ROW => Ok(StepResult::Row),
242 SQLITE_DONE => {
243 if self.current_statement >= self.raw_statements.len() - 1 {
244 Ok(StepResult::Done)
245 } else {
246 self.current_statement += 1;
247 self.step()
248 }
249 }
250 SQLITE_MISUSE => Ok(StepResult::Misuse),
251 other => self
252 .connection
253 .last_error()
254 .map(|_| StepResult::Other(other)),
255 }
256 }
257 }
258
259 pub fn insert(&mut self) -> Result<i64> {
260 self.exec()?;
261 Ok(self.connection.last_insert_id())
262 }
263
264 pub fn exec(&mut self) -> Result<()> {
265 fn logic(this: &mut Statement) -> Result<()> {
266 while this.step()? == StepResult::Row {}
267 Ok(())
268 }
269 let result = logic(self);
270 self.reset();
271 result
272 }
273
274 pub fn map<R>(&mut self, callback: impl FnMut(&mut Statement) -> Result<R>) -> Result<Vec<R>> {
275 fn logic<R>(
276 this: &mut Statement,
277 mut callback: impl FnMut(&mut Statement) -> Result<R>,
278 ) -> Result<Vec<R>> {
279 let mut mapped_rows = Vec::new();
280 while this.step()? == StepResult::Row {
281 mapped_rows.push(callback(this)?);
282 }
283 Ok(mapped_rows)
284 }
285
286 let result = logic(self, callback);
287 self.reset();
288 result
289 }
290
291 pub fn rows<R: Column>(&mut self) -> Result<Vec<R>> {
292 self.map(|s| s.column::<R>())
293 }
294
295 pub fn single<R>(&mut self, callback: impl FnOnce(&mut Statement) -> Result<R>) -> Result<R> {
296 fn logic<R>(
297 this: &mut Statement,
298 callback: impl FnOnce(&mut Statement) -> Result<R>,
299 ) -> Result<R> {
300 if this.step()? != StepResult::Row {
301 return Err(anyhow!(
302 "Single(Map) called with query that returns no rows."
303 ));
304 }
305 callback(this)
306 }
307 let result = logic(self, callback);
308 self.reset();
309 result
310 }
311
312 pub fn row<R: Column>(&mut self) -> Result<R> {
313 self.single(|this| this.column::<R>())
314 }
315
316 pub fn maybe<R>(
317 &mut self,
318 callback: impl FnOnce(&mut Statement) -> Result<R>,
319 ) -> Result<Option<R>> {
320 fn logic<R>(
321 this: &mut Statement,
322 callback: impl FnOnce(&mut Statement) -> Result<R>,
323 ) -> Result<Option<R>> {
324 if this.step()? != StepResult::Row {
325 return Ok(None);
326 }
327 callback(this).map(|r| Some(r))
328 }
329 let result = logic(self, callback);
330 self.reset();
331 result
332 }
333
334 pub fn maybe_row<R: Column>(&mut self) -> Result<Option<R>> {
335 self.maybe(|this| this.column::<R>())
336 }
337}
338
339impl<'a> Drop for Statement<'a> {
340 fn drop(&mut self) {
341 unsafe {
342 for raw_statement in self.raw_statements.iter() {
343 sqlite3_finalize(*raw_statement);
344 }
345 }
346 }
347}
348
#[cfg(test)]
mod test {
    use indoc::indoc;

    use crate::{
        connection::Connection,
        statement::{Statement, StepResult},
    };

    /// Round-trips a blob between two connections opened with the same in-memory database
    /// name. It inserts through `connection1` and reads back through `connection2`. It then
    /// deletes through `connection2` and confirms `connection1` no longer sees the row.
    #[test]
    fn blob_round_trips() {
        let connection1 = Connection::open_memory("blob_round_trips");
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                )"})
            .unwrap()()
            .unwrap();

        let blob = &[0, 1, 2, 4, 8, 16, 32, 64];

        let mut write =
            Statement::prepare(&connection1, "INSERT INTO blobs (data) VALUES (?)").unwrap();
        write.bind_blob(1, blob).unwrap();
        assert_eq!(write.step().unwrap(), StepResult::Done);

        // Read the blob back through a second connection opened with the same database
        // name; this relies on both connections sharing one in-memory database.
        let connection2 = Connection::open_memory("blob_round_trips");
        let mut read = Statement::prepare(&connection2, "SELECT * FROM blobs").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Row);
        assert_eq!(read.column_blob(0).unwrap(), blob);
        assert_eq!(read.step().unwrap(), StepResult::Done);

        // Delete the added blob and verify it's deleted on the other side
        connection2.exec("DELETE FROM blobs").unwrap()().unwrap();
        let mut read = Statement::prepare(&connection1, "SELECT * FROM blobs").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Done);
    }

    /// Checks that `select_row` yields `None` on an empty table, and `Some(text)` once a
    /// row has been inserted through `exec_bound`.
    #[test]
    pub fn maybe_returns_options() {
        let connection = Connection::open_memory("maybe_returns_options");
        connection
            .exec(indoc! {"
                CREATE TABLE texts (
                    text TEXT
                )"})
            .unwrap()()
            .unwrap();

        assert!(connection
            .select_row::<String>("SELECT text FROM texts")
            .unwrap()()
            .unwrap()
            .is_none());

        let text_to_insert = "This is a test";

        connection
            .exec_bound("INSERT INTO texts VALUES (?)")
            .unwrap()(text_to_insert)
            .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM texts").unwrap()().unwrap(),
            Some(text_to_insert.to_string())
        );
    }
}