1use std::ffi::{CStr, CString, c_int};
2use std::marker::PhantomData;
3use std::{ptr, slice, str};
4
5use anyhow::{Context as _, Result, bail};
6use libsqlite3_sys::*;
7
8use crate::bindable::{Bind, Column};
9use crate::connection::Connection;
10
11pub struct Statement<'a> {
12 /// vector of pointers to the raw SQLite statement objects.
13 /// it holds the actual prepared statements that will be executed.
14 pub raw_statements: Vec<*mut sqlite3_stmt>,
15 /// Index of the current statement being executed from the `raw_statements` vector.
16 current_statement: usize,
17 /// A reference to the database connection.
18 /// This is used to execute the statements and check for errors.
19 connection: &'a Connection,
20 ///Indicates that the `Statement` struct is tied to the lifetime of the SQLite statement
21 phantom: PhantomData<sqlite3_stmt>,
22}
23
/// Outcome of advancing a [`Statement`] by one step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StepResult {
    /// A result row is available to be read.
    Row,
    /// The last prepared statement has run to completion.
    Done,
}
29
/// Storage class of a column value, as reported by `sqlite3_column_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlType {
    /// Maps from `SQLITE_TEXT`.
    Text,
    /// Maps from `SQLITE_INTEGER`.
    Integer,
    /// Maps from `SQLITE_BLOB`.
    Blob,
    /// Maps from `SQLITE_FLOAT`.
    Float,
    /// Maps from `SQLITE_NULL`.
    Null,
}
38
39impl<'a> Statement<'a> {
40 pub fn prepare<T: AsRef<str>>(connection: &'a Connection, query: T) -> Result<Self> {
41 let mut statement = Self {
42 raw_statements: Default::default(),
43 current_statement: 0,
44 connection,
45 phantom: PhantomData,
46 };
47 let sql = CString::new(query.as_ref()).context("Error creating cstr")?;
48 let mut remaining_sql = sql.as_c_str();
49 while {
50 let remaining_sql_str = remaining_sql
51 .to_str()
52 .context("Parsing remaining sql")?
53 .trim();
54 remaining_sql_str != ";" && !remaining_sql_str.is_empty()
55 } {
56 let mut raw_statement = ptr::null_mut::<sqlite3_stmt>();
57 let mut remaining_sql_ptr = ptr::null();
58 unsafe {
59 sqlite3_prepare_v2(
60 connection.sqlite3,
61 remaining_sql.as_ptr(),
62 -1,
63 &mut raw_statement,
64 &mut remaining_sql_ptr,
65 )
66 };
67
68 connection
69 .last_error()
70 .with_context(|| format!("Prepare call failed for query:\n{}", query.as_ref()))?;
71
72 remaining_sql = unsafe { CStr::from_ptr(remaining_sql_ptr) };
73 statement.raw_statements.push(raw_statement);
74
75 if !connection.can_write() && unsafe { sqlite3_stmt_readonly(raw_statement) == 0 } {
76 let sql = unsafe { CStr::from_ptr(sqlite3_sql(raw_statement)) };
77
78 bail!(
79 "Write statement prepared with connection that is not write capable. SQL:\n{} ",
80 sql.to_str()?
81 )
82 }
83 }
84
85 Ok(statement)
86 }
87
88 fn current_statement(&self) -> *mut sqlite3_stmt {
89 *self.raw_statements.get(self.current_statement).unwrap()
90 }
91
92 pub fn reset(&mut self) {
93 unsafe {
94 for raw_statement in self.raw_statements.iter() {
95 sqlite3_reset(*raw_statement);
96 }
97 }
98 self.current_statement = 0;
99 }
100
101 pub fn parameter_count(&self) -> i32 {
102 unsafe {
103 self.raw_statements
104 .iter()
105 .map(|raw_statement| sqlite3_bind_parameter_count(*raw_statement))
106 .max()
107 .unwrap_or(0)
108 }
109 }
110
111 fn bind_index_with(&self, index: i32, bind: impl Fn(&*mut sqlite3_stmt)) -> Result<()> {
112 let mut any_succeed = false;
113 unsafe {
114 for raw_statement in self.raw_statements.iter() {
115 if index <= sqlite3_bind_parameter_count(*raw_statement) {
116 bind(raw_statement);
117 self.connection
118 .last_error()
119 .with_context(|| format!("Failed to bind value at index {index}"))?;
120 any_succeed = true;
121 } else {
122 continue;
123 }
124 }
125 }
126 if any_succeed {
127 Ok(())
128 } else {
129 anyhow::bail!("Failed to bind parameters")
130 }
131 }
132
133 pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
134 let index = index as c_int;
135 let blob_pointer = blob.as_ptr() as *const _;
136 let len = blob.len() as c_int;
137
138 self.bind_index_with(index, |raw_statement| unsafe {
139 sqlite3_bind_blob(*raw_statement, index, blob_pointer, len, SQLITE_TRANSIENT());
140 })
141 }
142
143 pub fn column_blob(&mut self, index: i32) -> Result<&[u8]> {
144 let index = index as c_int;
145 let pointer = unsafe { sqlite3_column_blob(self.current_statement(), index) };
146
147 self.connection
148 .last_error()
149 .with_context(|| format!("Failed to read blob at index {index}"))?;
150 if pointer.is_null() {
151 return Ok(&[]);
152 }
153 let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
154 self.connection
155 .last_error()
156 .with_context(|| format!("Failed to read length of blob at index {index}"))?;
157
158 unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
159 }
160
161 pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
162 let index = index as c_int;
163
164 self.bind_index_with(index, |raw_statement| unsafe {
165 sqlite3_bind_double(*raw_statement, index, double);
166 })
167 }
168
169 pub fn column_double(&self, index: i32) -> Result<f64> {
170 let index = index as c_int;
171 let result = unsafe { sqlite3_column_double(self.current_statement(), index) };
172 self.connection
173 .last_error()
174 .with_context(|| format!("Failed to read double at index {index}"))?;
175 Ok(result)
176 }
177
178 pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
179 let index = index as c_int;
180 self.bind_index_with(index, |raw_statement| unsafe {
181 sqlite3_bind_int(*raw_statement, index, int);
182 })
183 }
184
185 pub fn column_int(&self, index: i32) -> Result<i32> {
186 let index = index as c_int;
187 let result = unsafe { sqlite3_column_int(self.current_statement(), index) };
188 self.connection
189 .last_error()
190 .with_context(|| format!("Failed to read int at index {index}"))?;
191 Ok(result)
192 }
193
194 pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
195 let index = index as c_int;
196 self.bind_index_with(index, |raw_statement| unsafe {
197 sqlite3_bind_int64(*raw_statement, index, int);
198 })
199 }
200
201 pub fn column_int64(&self, index: i32) -> Result<i64> {
202 let index = index as c_int;
203 let result = unsafe { sqlite3_column_int64(self.current_statement(), index) };
204 self.connection
205 .last_error()
206 .with_context(|| format!("Failed to read i64 at index {index}"))?;
207 Ok(result)
208 }
209
210 pub fn bind_null(&self, index: i32) -> Result<()> {
211 let index = index as c_int;
212 self.bind_index_with(index, |raw_statement| unsafe {
213 sqlite3_bind_null(*raw_statement, index);
214 })
215 }
216
217 pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
218 let index = index as c_int;
219 let text_pointer = text.as_ptr() as *const _;
220 let len = text.len() as c_int;
221
222 self.bind_index_with(index, |raw_statement| unsafe {
223 sqlite3_bind_text(*raw_statement, index, text_pointer, len, SQLITE_TRANSIENT());
224 })
225 }
226
227 pub fn column_text(&mut self, index: i32) -> Result<&str> {
228 let index = index as c_int;
229 let pointer = unsafe { sqlite3_column_text(self.current_statement(), index) };
230
231 self.connection
232 .last_error()
233 .with_context(|| format!("Failed to read text from column {index}"))?;
234 if pointer.is_null() {
235 return Ok("");
236 }
237 let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
238 self.connection
239 .last_error()
240 .with_context(|| format!("Failed to read text length at {index}"))?;
241
242 let slice = unsafe { slice::from_raw_parts(pointer, len) };
243 Ok(str::from_utf8(slice)?)
244 }
245
246 pub fn bind<T: Bind>(&self, value: &T, index: i32) -> Result<i32> {
247 debug_assert!(index > 0);
248 value.bind(self, index)
249 }
250
251 pub fn column<T: Column>(&mut self) -> Result<T> {
252 Ok(T::column(self, 0)?.0)
253 }
254
255 pub fn column_type(&mut self, index: i32) -> Result<SqlType> {
256 let result = unsafe { sqlite3_column_type(self.current_statement(), index) };
257 self.connection.last_error()?;
258 match result {
259 SQLITE_INTEGER => Ok(SqlType::Integer),
260 SQLITE_FLOAT => Ok(SqlType::Float),
261 SQLITE_TEXT => Ok(SqlType::Text),
262 SQLITE_BLOB => Ok(SqlType::Blob),
263 SQLITE_NULL => Ok(SqlType::Null),
264 _ => anyhow::bail!("Column type returned was incorrect"),
265 }
266 }
267
268 pub fn with_bindings(&mut self, bindings: &impl Bind) -> Result<&mut Self> {
269 self.bind(bindings, 1)?;
270 Ok(self)
271 }
272
273 fn step(&mut self) -> Result<StepResult> {
274 match unsafe { sqlite3_step(self.current_statement()) } {
275 SQLITE_ROW => Ok(StepResult::Row),
276 SQLITE_DONE => {
277 if self.current_statement >= self.raw_statements.len() - 1 {
278 Ok(StepResult::Done)
279 } else {
280 self.current_statement += 1;
281 self.step()
282 }
283 }
284 SQLITE_MISUSE => anyhow::bail!("Statement step returned SQLITE_MISUSE"),
285 _other_error => {
286 self.connection.last_error()?;
287 unreachable!("Step returned error code and last error failed to catch it");
288 }
289 }
290 }
291
292 pub fn exec(&mut self) -> Result<()> {
293 fn logic(this: &mut Statement) -> Result<()> {
294 while this.step()? == StepResult::Row {}
295 Ok(())
296 }
297 let result = logic(self);
298 self.reset();
299 result
300 }
301
302 pub fn map<R>(&mut self, callback: impl FnMut(&mut Statement) -> Result<R>) -> Result<Vec<R>> {
303 fn logic<R>(
304 this: &mut Statement,
305 mut callback: impl FnMut(&mut Statement) -> Result<R>,
306 ) -> Result<Vec<R>> {
307 let mut mapped_rows = Vec::new();
308 while this.step()? == StepResult::Row {
309 mapped_rows.push(callback(this)?);
310 }
311 Ok(mapped_rows)
312 }
313
314 let result = logic(self, callback);
315 self.reset();
316 result
317 }
318
319 pub fn rows<R: Column>(&mut self) -> Result<Vec<R>> {
320 self.map(|s| s.column::<R>())
321 }
322
323 pub fn single<R>(&mut self, callback: impl FnOnce(&mut Statement) -> Result<R>) -> Result<R> {
324 fn logic<R>(
325 this: &mut Statement,
326 callback: impl FnOnce(&mut Statement) -> Result<R>,
327 ) -> Result<R> {
328 println!("{:?}", std::any::type_name::<R>());
329 anyhow::ensure!(
330 this.step()? == StepResult::Row,
331 "single called with query that returns no rows."
332 );
333 let result = callback(this)?;
334
335 anyhow::ensure!(
336 this.step()? == StepResult::Done,
337 "single called with a query that returns more than one row."
338 );
339
340 Ok(result)
341 }
342 let result = logic(self, callback);
343 self.reset();
344 result
345 }
346
347 pub fn row<R: Column>(&mut self) -> Result<R> {
348 self.single(|this| this.column::<R>())
349 }
350
351 pub fn maybe<R>(
352 &mut self,
353 callback: impl FnOnce(&mut Statement) -> Result<R>,
354 ) -> Result<Option<R>> {
355 fn logic<R>(
356 this: &mut Statement,
357 callback: impl FnOnce(&mut Statement) -> Result<R>,
358 ) -> Result<Option<R>> {
359 if this.step().context("Failed on step call")? != StepResult::Row {
360 return Ok(None);
361 }
362
363 let result = callback(this)
364 .map(|r| Some(r))
365 .context("Failed to parse row result")?;
366
367 anyhow::ensure!(
368 this.step().context("Second step call")? == StepResult::Done,
369 "maybe called with a query that returns more than one row."
370 );
371
372 Ok(result)
373 }
374 let result = logic(self, callback);
375 self.reset();
376 result
377 }
378
379 pub fn maybe_row<R: Column>(&mut self) -> Result<Option<R>> {
380 self.maybe(|this| this.column::<R>())
381 }
382}
383
384impl Drop for Statement<'_> {
385 fn drop(&mut self) {
386 unsafe {
387 for raw_statement in self.raw_statements.iter() {
388 sqlite3_finalize(*raw_statement);
389 }
390 }
391 }
392}
393
#[cfg(test)]
mod test {
    use indoc::indoc;

    use crate::{
        connection::Connection,
        statement::{Statement, StepResult},
    };

    /// Every parameter index up to the highest one used by any statement can be bound,
    /// even though each statement only uses some of them (`?3` in the first, `?1` in
    /// the second).
    #[test]
    fn binding_multiple_statements_with_parameter_gaps() {
        let connection =
            Connection::open_memory(Some("binding_multiple_statements_with_parameter_gaps"));

        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    col INTEGER
                )"})
            .unwrap()()
        .unwrap();

        let statement = Statement::prepare(
            &connection,
            indoc! {"
                INSERT INTO test(col) VALUES (?3);
                SELECT * FROM test WHERE col = ?1"},
        )
        .unwrap();

        // Each index is bound to its own value; the message names the index on failure.
        let bindings = [
            (1, "Could not bind parameter to first index"),
            (2, "Could not bind parameter to second index"),
            (3, "Could not bind parameter to third index"),
        ];
        for (index, failure_message) in bindings {
            statement.bind_int(index, index).expect(failure_message);
        }
    }

    /// A blob written through one connection reads back unchanged through a second
    /// connection opened with the same name. A delete made through the second
    /// connection is then not seen by the first.
    #[test]
    fn blob_round_trips() {
        let primary = Connection::open_memory(Some("blob_round_trips"));
        primary
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                )"})
            .unwrap()()
        .unwrap();

        let blob = &[0, 1, 2, 4, 8, 16, 32, 64];

        let mut insert =
            Statement::prepare(&primary, "INSERT INTO blobs (data) VALUES (?)").unwrap();
        insert.bind_blob(1, blob).unwrap();
        assert_eq!(insert.step().unwrap(), StepResult::Done);

        // Read the blob back through a second connection opened with the same name.
        let secondary = Connection::open_memory(Some("blob_round_trips"));
        let mut select = Statement::prepare(&secondary, "SELECT * FROM blobs").unwrap();
        assert_eq!(select.step().unwrap(), StepResult::Row);
        assert_eq!(select.column_blob(0).unwrap(), blob);
        assert_eq!(select.step().unwrap(), StepResult::Done);

        // Delete the blob through the second connection; the first must no longer see it.
        secondary.exec("DELETE FROM blobs").unwrap()().unwrap();
        let mut select = Statement::prepare(&primary, "SELECT * FROM blobs").unwrap();
        assert_eq!(select.step().unwrap(), StepResult::Done);
    }

    /// `select_row` yields `None` for an empty table and `Some` once a row exists.
    #[test]
    pub fn maybe_returns_options() {
        let connection = Connection::open_memory(Some("maybe_returns_options"));
        connection
            .exec(indoc! {"
                CREATE TABLE texts (
                    text TEXT
                )"})
            .unwrap()()
        .unwrap();

        let before_insert = connection
            .select_row::<String>("SELECT text FROM texts")
            .unwrap()()
        .unwrap();
        assert!(before_insert.is_none());

        let text_to_insert = "This is a test";

        connection
            .exec_bound("INSERT INTO texts VALUES (?)")
            .unwrap()(text_to_insert)
        .unwrap();

        let after_insert = connection
            .select_row::<String>("SELECT text FROM texts")
            .unwrap()()
        .unwrap();
        assert_eq!(after_insert, Some(text_to_insert.to_string()));
    }
}