1use std::ffi::{c_int, CStr, CString};
2use std::marker::PhantomData;
3use std::{ptr, slice, str};
4
5use anyhow::{anyhow, bail, Context, Result};
6use libsqlite3_sys::*;
7
8use crate::bindable::{Bind, Column};
9use crate::connection::Connection;
10
/// A prepared query made of one or more raw SQLite statements, tied to the
/// [`Connection`] it was prepared on.
pub struct Statement<'a> {
    /// Raw SQLite statement handles, in the order they were prepared.
    /// These are the handles that get stepped, reset and finalized.
    pub raw_statements: Vec<*mut sqlite3_stmt>,
    /// Index into `raw_statements` of the statement currently being executed.
    current_statement: usize,
    /// The database connection the statements were prepared on.
    /// It is consulted after SQLite calls to check for errors.
    connection: &'a Connection,
    /// Zero-sized marker for the `sqlite3_stmt` objects referenced through `raw_statements`.
    phantom: PhantomData<sqlite3_stmt>,
}
23
/// Outcome of successfully stepping a [`Statement`].
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum StepResult {
    /// SQLite reported `SQLITE_ROW`: a row is available through the `column_*` accessors.
    Row,
    /// The last statement reported `SQLITE_DONE`: there is nothing further to execute.
    Done,
}
29
/// Type of a column value, as mapped from SQLite's type codes by `Statement::column_type`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum SqlType {
    /// Mapped from `SQLITE_TEXT`.
    Text,
    /// Mapped from `SQLITE_INTEGER`.
    Integer,
    /// Mapped from `SQLITE_BLOB`.
    Blob,
    /// Mapped from `SQLITE_FLOAT`.
    Float,
    /// Mapped from `SQLITE_NULL`.
    Null,
}
38
39impl<'a> Statement<'a> {
40 pub fn prepare<T: AsRef<str>>(connection: &'a Connection, query: T) -> Result<Self> {
41 let mut statement = Self {
42 raw_statements: Default::default(),
43 current_statement: 0,
44 connection,
45 phantom: PhantomData,
46 };
47 unsafe {
48 let sql = CString::new(query.as_ref()).context("Error creating cstr")?;
49 let mut remaining_sql = sql.as_c_str();
50 while {
51 let remaining_sql_str = remaining_sql
52 .to_str()
53 .context("Parsing remaining sql")?
54 .trim();
55 remaining_sql_str != ";" && !remaining_sql_str.is_empty()
56 } {
57 let mut raw_statement = ptr::null_mut::<sqlite3_stmt>();
58 let mut remaining_sql_ptr = ptr::null();
59 sqlite3_prepare_v2(
60 connection.sqlite3,
61 remaining_sql.as_ptr(),
62 -1,
63 &mut raw_statement,
64 &mut remaining_sql_ptr,
65 );
66
67 connection.last_error().with_context(|| {
68 format!("Prepare call failed for query:\n{}", query.as_ref())
69 })?;
70
71 remaining_sql = CStr::from_ptr(remaining_sql_ptr);
72 statement.raw_statements.push(raw_statement);
73
74 if !connection.can_write() && sqlite3_stmt_readonly(raw_statement) == 0 {
75 let sql = CStr::from_ptr(sqlite3_sql(raw_statement));
76
77 bail!(
78 "Write statement prepared with connection that is not write capable. SQL:\n{} ",
79 sql.to_str()?)
80 }
81 }
82 }
83
84 Ok(statement)
85 }
86
87 fn current_statement(&self) -> *mut sqlite3_stmt {
88 *self.raw_statements.get(self.current_statement).unwrap()
89 }
90
91 pub fn reset(&mut self) {
92 unsafe {
93 for raw_statement in self.raw_statements.iter() {
94 sqlite3_reset(*raw_statement);
95 }
96 }
97 self.current_statement = 0;
98 }
99
100 pub fn parameter_count(&self) -> i32 {
101 unsafe {
102 self.raw_statements
103 .iter()
104 .map(|raw_statement| sqlite3_bind_parameter_count(*raw_statement))
105 .max()
106 .unwrap_or(0)
107 }
108 }
109
110 fn bind_index_with(&self, index: i32, bind: impl Fn(&*mut sqlite3_stmt)) -> Result<()> {
111 let mut any_succeed = false;
112 unsafe {
113 for raw_statement in self.raw_statements.iter() {
114 if index <= sqlite3_bind_parameter_count(*raw_statement) {
115 bind(raw_statement);
116 self.connection
117 .last_error()
118 .with_context(|| format!("Failed to bind value at index {index}"))?;
119 any_succeed = true;
120 } else {
121 continue;
122 }
123 }
124 }
125 if any_succeed {
126 Ok(())
127 } else {
128 Err(anyhow!("Failed to bind parameters"))
129 }
130 }
131
132 pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
133 let index = index as c_int;
134 let blob_pointer = blob.as_ptr() as *const _;
135 let len = blob.len() as c_int;
136
137 self.bind_index_with(index, |raw_statement| unsafe {
138 sqlite3_bind_blob(*raw_statement, index, blob_pointer, len, SQLITE_TRANSIENT());
139 })
140 }
141
142 pub fn column_blob(&mut self, index: i32) -> Result<&[u8]> {
143 let index = index as c_int;
144 let pointer = unsafe { sqlite3_column_blob(self.current_statement(), index) };
145
146 self.connection
147 .last_error()
148 .with_context(|| format!("Failed to read blob at index {index}"))?;
149 if pointer.is_null() {
150 return Ok(&[]);
151 }
152 let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
153 self.connection
154 .last_error()
155 .with_context(|| format!("Failed to read length of blob at index {index}"))?;
156
157 unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
158 }
159
160 pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
161 let index = index as c_int;
162
163 self.bind_index_with(index, |raw_statement| unsafe {
164 sqlite3_bind_double(*raw_statement, index, double);
165 })
166 }
167
168 pub fn column_double(&self, index: i32) -> Result<f64> {
169 let index = index as c_int;
170 let result = unsafe { sqlite3_column_double(self.current_statement(), index) };
171 self.connection
172 .last_error()
173 .with_context(|| format!("Failed to read double at index {index}"))?;
174 Ok(result)
175 }
176
177 pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
178 let index = index as c_int;
179 self.bind_index_with(index, |raw_statement| unsafe {
180 sqlite3_bind_int(*raw_statement, index, int);
181 })
182 }
183
184 pub fn column_int(&self, index: i32) -> Result<i32> {
185 let index = index as c_int;
186 let result = unsafe { sqlite3_column_int(self.current_statement(), index) };
187 self.connection
188 .last_error()
189 .with_context(|| format!("Failed to read int at index {index}"))?;
190 Ok(result)
191 }
192
193 pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
194 let index = index as c_int;
195 self.bind_index_with(index, |raw_statement| unsafe {
196 sqlite3_bind_int64(*raw_statement, index, int);
197 })
198 }
199
200 pub fn column_int64(&self, index: i32) -> Result<i64> {
201 let index = index as c_int;
202 let result = unsafe { sqlite3_column_int64(self.current_statement(), index) };
203 self.connection
204 .last_error()
205 .with_context(|| format!("Failed to read i64 at index {index}"))?;
206 Ok(result)
207 }
208
209 pub fn bind_null(&self, index: i32) -> Result<()> {
210 let index = index as c_int;
211 self.bind_index_with(index, |raw_statement| unsafe {
212 sqlite3_bind_null(*raw_statement, index);
213 })
214 }
215
216 pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
217 let index = index as c_int;
218 let text_pointer = text.as_ptr() as *const _;
219 let len = text.len() as c_int;
220
221 self.bind_index_with(index, |raw_statement| unsafe {
222 sqlite3_bind_text(*raw_statement, index, text_pointer, len, SQLITE_TRANSIENT());
223 })
224 }
225
226 pub fn column_text(&mut self, index: i32) -> Result<&str> {
227 let index = index as c_int;
228 let pointer = unsafe { sqlite3_column_text(self.current_statement(), index) };
229
230 self.connection
231 .last_error()
232 .with_context(|| format!("Failed to read text from column {index}"))?;
233 if pointer.is_null() {
234 return Ok("");
235 }
236 let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
237 self.connection
238 .last_error()
239 .with_context(|| format!("Failed to read text length at {index}"))?;
240
241 let slice = unsafe { slice::from_raw_parts(pointer, len) };
242 Ok(str::from_utf8(slice)?)
243 }
244
245 pub fn bind<T: Bind>(&self, value: &T, index: i32) -> Result<i32> {
246 debug_assert!(index > 0);
247 value.bind(self, index)
248 }
249
250 pub fn column<T: Column>(&mut self) -> Result<T> {
251 Ok(T::column(self, 0)?.0)
252 }
253
254 pub fn column_type(&mut self, index: i32) -> Result<SqlType> {
255 let result = unsafe { sqlite3_column_type(self.current_statement(), index) };
256 self.connection.last_error()?;
257 match result {
258 SQLITE_INTEGER => Ok(SqlType::Integer),
259 SQLITE_FLOAT => Ok(SqlType::Float),
260 SQLITE_TEXT => Ok(SqlType::Text),
261 SQLITE_BLOB => Ok(SqlType::Blob),
262 SQLITE_NULL => Ok(SqlType::Null),
263 _ => Err(anyhow!("Column type returned was incorrect ")),
264 }
265 }
266
267 pub fn with_bindings(&mut self, bindings: &impl Bind) -> Result<&mut Self> {
268 self.bind(bindings, 1)?;
269 Ok(self)
270 }
271
272 fn step(&mut self) -> Result<StepResult> {
273 unsafe {
274 match sqlite3_step(self.current_statement()) {
275 SQLITE_ROW => Ok(StepResult::Row),
276 SQLITE_DONE => {
277 if self.current_statement >= self.raw_statements.len() - 1 {
278 Ok(StepResult::Done)
279 } else {
280 self.current_statement += 1;
281 self.step()
282 }
283 }
284 SQLITE_MISUSE => Err(anyhow!("Statement step returned SQLITE_MISUSE")),
285 _other_error => {
286 self.connection.last_error()?;
287 unreachable!("Step returned error code and last error failed to catch it");
288 }
289 }
290 }
291 }
292
293 pub fn exec(&mut self) -> Result<()> {
294 fn logic(this: &mut Statement) -> Result<()> {
295 while this.step()? == StepResult::Row {}
296 Ok(())
297 }
298 let result = logic(self);
299 self.reset();
300 result
301 }
302
303 pub fn map<R>(&mut self, callback: impl FnMut(&mut Statement) -> Result<R>) -> Result<Vec<R>> {
304 fn logic<R>(
305 this: &mut Statement,
306 mut callback: impl FnMut(&mut Statement) -> Result<R>,
307 ) -> Result<Vec<R>> {
308 let mut mapped_rows = Vec::new();
309 while this.step()? == StepResult::Row {
310 mapped_rows.push(callback(this)?);
311 }
312 Ok(mapped_rows)
313 }
314
315 let result = logic(self, callback);
316 self.reset();
317 result
318 }
319
320 pub fn rows<R: Column>(&mut self) -> Result<Vec<R>> {
321 self.map(|s| s.column::<R>())
322 }
323
324 pub fn single<R>(&mut self, callback: impl FnOnce(&mut Statement) -> Result<R>) -> Result<R> {
325 fn logic<R>(
326 this: &mut Statement,
327 callback: impl FnOnce(&mut Statement) -> Result<R>,
328 ) -> Result<R> {
329 println!("{:?}", std::any::type_name::<R>());
330 if this.step()? != StepResult::Row {
331 return Err(anyhow!("single called with query that returns no rows."));
332 }
333 let result = callback(this)?;
334
335 if this.step()? != StepResult::Done {
336 return Err(anyhow!(
337 "single called with a query that returns more than one row."
338 ));
339 }
340
341 Ok(result)
342 }
343 let result = logic(self, callback);
344 self.reset();
345 result
346 }
347
348 pub fn row<R: Column>(&mut self) -> Result<R> {
349 self.single(|this| this.column::<R>())
350 }
351
352 pub fn maybe<R>(
353 &mut self,
354 callback: impl FnOnce(&mut Statement) -> Result<R>,
355 ) -> Result<Option<R>> {
356 fn logic<R>(
357 this: &mut Statement,
358 callback: impl FnOnce(&mut Statement) -> Result<R>,
359 ) -> Result<Option<R>> {
360 if this.step().context("Failed on step call")? != StepResult::Row {
361 return Ok(None);
362 }
363
364 let result = callback(this)
365 .map(|r| Some(r))
366 .context("Failed to parse row result")?;
367
368 if this.step().context("Second step call")? != StepResult::Done {
369 return Err(anyhow!(
370 "maybe called with a query that returns more than one row."
371 ));
372 }
373
374 Ok(result)
375 }
376 let result = logic(self, callback);
377 self.reset();
378 result
379 }
380
381 pub fn maybe_row<R: Column>(&mut self) -> Result<Option<R>> {
382 self.maybe(|this| this.column::<R>())
383 }
384}
385
386impl Drop for Statement<'_> {
387 fn drop(&mut self) {
388 unsafe {
389 for raw_statement in self.raw_statements.iter() {
390 sqlite3_finalize(*raw_statement);
391 }
392 }
393 }
394}
395
#[cfg(test)]
mod test {
    use indoc::indoc;

    use crate::{
        connection::Connection,
        statement::{Statement, StepResult},
    };

    // Binds indices 1, 2 and 3 on a two-statement query whose statements use
    // `?3` and `?1`; every bind is expected to succeed.
    #[test]
    fn binding_multiple_statements_with_parameter_gaps() {
        let connection =
            Connection::open_memory(Some("binding_multiple_statements_with_parameter_gaps"));

        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    col INTEGER
                )"})
            .unwrap()()
            .unwrap();

        let statement = Statement::prepare(
            &connection,
            indoc! {"
                INSERT INTO test(col) VALUES (?3);
                SELECT * FROM test WHERE col = ?1"},
        )
        .unwrap();

        statement
            .bind_int(1, 1)
            .expect("Could not bind parameter to first index");
        statement
            .bind_int(2, 2)
            .expect("Could not bind parameter to second index");
        statement
            .bind_int(3, 3)
            .expect("Could not bind parameter to third index");
    }

    // Writes a blob through one connection and reads it back through another.
    #[test]
    fn blob_round_trips() {
        let connection1 = Connection::open_memory(Some("blob_round_trips"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                )"})
            .unwrap()()
            .unwrap();

        let blob = &[0, 1, 2, 4, 8, 16, 32, 64];

        let mut write =
            Statement::prepare(&connection1, "INSERT INTO blobs (data) VALUES (?)").unwrap();
        write.bind_blob(1, blob).unwrap();
        assert_eq!(write.step().unwrap(), StepResult::Done);

        // Read the blob back through a second connection opened under the same name
        let connection2 = Connection::open_memory(Some("blob_round_trips"));
        let mut read = Statement::prepare(&connection2, "SELECT * FROM blobs").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Row);
        assert_eq!(read.column_blob(0).unwrap(), blob);
        assert_eq!(read.step().unwrap(), StepResult::Done);

        // Delete the added blob and verify it's deleted on the other side
        connection2.exec("DELETE FROM blobs").unwrap()().unwrap();
        let mut read = Statement::prepare(&connection1, "SELECT * FROM blobs").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Done);
    }

    // `select_row` is expected to give `None` for an empty table and
    // `Some(..)` once a row has been inserted.
    #[test]
    pub fn maybe_returns_options() {
        let connection = Connection::open_memory(Some("maybe_returns_options"));
        connection
            .exec(indoc! {"
                CREATE TABLE texts (
                    text TEXT
                )"})
            .unwrap()()
            .unwrap();

        assert!(connection
            .select_row::<String>("SELECT text FROM texts")
            .unwrap()()
        .unwrap()
        .is_none());

        let text_to_insert = "This is a test";

        connection
            .exec_bound("INSERT INTO texts VALUES (?)")
            .unwrap()(text_to_insert)
        .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM texts").unwrap()().unwrap(),
            Some(text_to_insert.to_string())
        );
    }
}