1use std::ffi::{c_int, CStr, CString};
2use std::marker::PhantomData;
3use std::{ptr, slice, str};
4
5use anyhow::{anyhow, bail, Context, Result};
6use libsqlite3_sys::*;
7
8use crate::bindable::{Bind, Column};
9use crate::connection::Connection;
10
11pub struct Statement<'a> {
12 raw_statements: Vec<*mut sqlite3_stmt>,
13 current_statement: usize,
14 connection: &'a Connection,
15 phantom: PhantomData<sqlite3_stmt>,
16}
17
/// Outcome of advancing a statement batch by one step.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum StepResult {
    /// A result row is available to read.
    Row,
    /// The last statement in the batch has finished running.
    Done,
}
23
/// Storage class of a column value, as reported by `sqlite3_column_type`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum SqlType {
    /// `SQLITE_TEXT`
    Text,
    /// `SQLITE_INTEGER`
    Integer,
    /// `SQLITE_BLOB`
    Blob,
    /// `SQLITE_FLOAT`
    Float,
    /// `SQLITE_NULL`
    Null,
}
32
33impl<'a> Statement<'a> {
34 pub fn prepare<T: AsRef<str>>(connection: &'a Connection, query: T) -> Result<Self> {
35 let mut statement = Self {
36 raw_statements: Default::default(),
37 current_statement: 0,
38 connection,
39 phantom: PhantomData,
40 };
41 unsafe {
42 let sql = CString::new(query.as_ref()).context("Error creating cstr")?;
43 let mut remaining_sql = sql.as_c_str();
44 while {
45 let remaining_sql_str = remaining_sql
46 .to_str()
47 .context("Parsing remaining sql")?
48 .trim();
49 remaining_sql_str != ";" && !remaining_sql_str.is_empty()
50 } {
51 let mut raw_statement = ptr::null_mut::<sqlite3_stmt>();
52 let mut remaining_sql_ptr = ptr::null();
53 sqlite3_prepare_v2(
54 connection.sqlite3,
55 remaining_sql.as_ptr(),
56 -1,
57 &mut raw_statement,
58 &mut remaining_sql_ptr,
59 );
60
61 connection.last_error().with_context(|| {
62 format!("Prepare call failed for query:\n{}", query.as_ref())
63 })?;
64
65 remaining_sql = CStr::from_ptr(remaining_sql_ptr);
66 statement.raw_statements.push(raw_statement);
67
68 if !connection.can_write() && sqlite3_stmt_readonly(raw_statement) == 0 {
69 let sql = CStr::from_ptr(sqlite3_sql(raw_statement));
70
71 bail!(
72 "Write statement prepared with connection that is not write capable. SQL:\n{} ",
73 sql.to_str()?)
74 }
75 }
76 }
77
78 Ok(statement)
79 }
80
81 fn current_statement(&self) -> *mut sqlite3_stmt {
82 *self.raw_statements.get(self.current_statement).unwrap()
83 }
84
85 pub fn reset(&mut self) {
86 unsafe {
87 for raw_statement in self.raw_statements.iter() {
88 sqlite3_reset(*raw_statement);
89 }
90 }
91 self.current_statement = 0;
92 }
93
94 pub fn parameter_count(&self) -> i32 {
95 unsafe {
96 self.raw_statements
97 .iter()
98 .map(|raw_statement| sqlite3_bind_parameter_count(*raw_statement))
99 .max()
100 .unwrap_or(0)
101 }
102 }
103
104 fn bind_index_with(&self, index: i32, bind: impl Fn(&*mut sqlite3_stmt)) -> Result<()> {
105 let mut any_succeed = false;
106 unsafe {
107 for raw_statement in self.raw_statements.iter() {
108 if index <= sqlite3_bind_parameter_count(*raw_statement) {
109 bind(raw_statement);
110 self.connection
111 .last_error()
112 .with_context(|| format!("Failed to bind value at index {index}"))?;
113 any_succeed = true;
114 } else {
115 continue;
116 }
117 }
118 }
119 if any_succeed {
120 Ok(())
121 } else {
122 Err(anyhow!("Failed to bind parameters"))
123 }
124 }
125
126 pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
127 let index = index as c_int;
128 let blob_pointer = blob.as_ptr() as *const _;
129 let len = blob.len() as c_int;
130
131 self.bind_index_with(index, |raw_statement| unsafe {
132 sqlite3_bind_blob(*raw_statement, index, blob_pointer, len, SQLITE_TRANSIENT());
133 })
134 }
135
136 pub fn column_blob(&mut self, index: i32) -> Result<&[u8]> {
137 let index = index as c_int;
138 let pointer = unsafe { sqlite3_column_blob(self.current_statement(), index) };
139
140 self.connection
141 .last_error()
142 .with_context(|| format!("Failed to read blob at index {index}"))?;
143 if pointer.is_null() {
144 return Ok(&[]);
145 }
146 let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
147 self.connection
148 .last_error()
149 .with_context(|| format!("Failed to read length of blob at index {index}"))?;
150
151 unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
152 }
153
154 pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
155 let index = index as c_int;
156
157 self.bind_index_with(index, |raw_statement| unsafe {
158 sqlite3_bind_double(*raw_statement, index, double);
159 })
160 }
161
162 pub fn column_double(&self, index: i32) -> Result<f64> {
163 let index = index as c_int;
164 let result = unsafe { sqlite3_column_double(self.current_statement(), index) };
165 self.connection
166 .last_error()
167 .with_context(|| format!("Failed to read double at index {index}"))?;
168 Ok(result)
169 }
170
171 pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
172 let index = index as c_int;
173 self.bind_index_with(index, |raw_statement| unsafe {
174 sqlite3_bind_int(*raw_statement, index, int);
175 })
176 }
177
178 pub fn column_int(&self, index: i32) -> Result<i32> {
179 let index = index as c_int;
180 let result = unsafe { sqlite3_column_int(self.current_statement(), index) };
181 self.connection
182 .last_error()
183 .with_context(|| format!("Failed to read int at index {index}"))?;
184 Ok(result)
185 }
186
187 pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
188 let index = index as c_int;
189 self.bind_index_with(index, |raw_statement| unsafe {
190 sqlite3_bind_int64(*raw_statement, index, int);
191 })
192 }
193
194 pub fn column_int64(&self, index: i32) -> Result<i64> {
195 let index = index as c_int;
196 let result = unsafe { sqlite3_column_int64(self.current_statement(), index) };
197 self.connection
198 .last_error()
199 .with_context(|| format!("Failed to read i64 at index {index}"))?;
200 Ok(result)
201 }
202
203 pub fn bind_null(&self, index: i32) -> Result<()> {
204 let index = index as c_int;
205 self.bind_index_with(index, |raw_statement| unsafe {
206 sqlite3_bind_null(*raw_statement, index);
207 })
208 }
209
210 pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
211 let index = index as c_int;
212 let text_pointer = text.as_ptr() as *const _;
213 let len = text.len() as c_int;
214
215 self.bind_index_with(index, |raw_statement| unsafe {
216 sqlite3_bind_text(*raw_statement, index, text_pointer, len, SQLITE_TRANSIENT());
217 })
218 }
219
220 pub fn column_text(&mut self, index: i32) -> Result<&str> {
221 let index = index as c_int;
222 let pointer = unsafe { sqlite3_column_text(self.current_statement(), index) };
223
224 self.connection
225 .last_error()
226 .with_context(|| format!("Failed to read text from column {index}"))?;
227 if pointer.is_null() {
228 return Ok("");
229 }
230 let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
231 self.connection
232 .last_error()
233 .with_context(|| format!("Failed to read text length at {index}"))?;
234
235 let slice = unsafe { slice::from_raw_parts(pointer, len) };
236 Ok(str::from_utf8(slice)?)
237 }
238
239 pub fn bind<T: Bind>(&self, value: &T, index: i32) -> Result<i32> {
240 debug_assert!(index > 0);
241 value.bind(self, index)
242 }
243
244 pub fn column<T: Column>(&mut self) -> Result<T> {
245 Ok(T::column(self, 0)?.0)
246 }
247
248 pub fn column_type(&mut self, index: i32) -> Result<SqlType> {
249 let result = unsafe { sqlite3_column_type(self.current_statement(), index) };
250 self.connection.last_error()?;
251 match result {
252 SQLITE_INTEGER => Ok(SqlType::Integer),
253 SQLITE_FLOAT => Ok(SqlType::Float),
254 SQLITE_TEXT => Ok(SqlType::Text),
255 SQLITE_BLOB => Ok(SqlType::Blob),
256 SQLITE_NULL => Ok(SqlType::Null),
257 _ => Err(anyhow!("Column type returned was incorrect ")),
258 }
259 }
260
261 pub fn with_bindings(&mut self, bindings: &impl Bind) -> Result<&mut Self> {
262 self.bind(bindings, 1)?;
263 Ok(self)
264 }
265
266 fn step(&mut self) -> Result<StepResult> {
267 unsafe {
268 match sqlite3_step(self.current_statement()) {
269 SQLITE_ROW => Ok(StepResult::Row),
270 SQLITE_DONE => {
271 if self.current_statement >= self.raw_statements.len() - 1 {
272 Ok(StepResult::Done)
273 } else {
274 self.current_statement += 1;
275 self.step()
276 }
277 }
278 SQLITE_MISUSE => Err(anyhow!("Statement step returned SQLITE_MISUSE")),
279 _other_error => {
280 self.connection.last_error()?;
281 unreachable!("Step returned error code and last error failed to catch it");
282 }
283 }
284 }
285 }
286
287 pub fn exec(&mut self) -> Result<()> {
288 fn logic(this: &mut Statement) -> Result<()> {
289 while this.step()? == StepResult::Row {}
290 Ok(())
291 }
292 let result = logic(self);
293 self.reset();
294 result
295 }
296
297 pub fn map<R>(&mut self, callback: impl FnMut(&mut Statement) -> Result<R>) -> Result<Vec<R>> {
298 fn logic<R>(
299 this: &mut Statement,
300 mut callback: impl FnMut(&mut Statement) -> Result<R>,
301 ) -> Result<Vec<R>> {
302 let mut mapped_rows = Vec::new();
303 while this.step()? == StepResult::Row {
304 mapped_rows.push(callback(this)?);
305 }
306 Ok(mapped_rows)
307 }
308
309 let result = logic(self, callback);
310 self.reset();
311 result
312 }
313
314 pub fn rows<R: Column>(&mut self) -> Result<Vec<R>> {
315 self.map(|s| s.column::<R>())
316 }
317
318 pub fn single<R>(&mut self, callback: impl FnOnce(&mut Statement) -> Result<R>) -> Result<R> {
319 fn logic<R>(
320 this: &mut Statement,
321 callback: impl FnOnce(&mut Statement) -> Result<R>,
322 ) -> Result<R> {
323 println!("{:?}", std::any::type_name::<R>());
324 if this.step()? != StepResult::Row {
325 return Err(anyhow!("single called with query that returns no rows."));
326 }
327 let result = callback(this)?;
328
329 if this.step()? != StepResult::Done {
330 return Err(anyhow!(
331 "single called with a query that returns more than one row."
332 ));
333 }
334
335 Ok(result)
336 }
337 let result = logic(self, callback);
338 self.reset();
339 result
340 }
341
342 pub fn row<R: Column>(&mut self) -> Result<R> {
343 self.single(|this| this.column::<R>())
344 }
345
346 pub fn maybe<R>(
347 &mut self,
348 callback: impl FnOnce(&mut Statement) -> Result<R>,
349 ) -> Result<Option<R>> {
350 fn logic<R>(
351 this: &mut Statement,
352 callback: impl FnOnce(&mut Statement) -> Result<R>,
353 ) -> Result<Option<R>> {
354 if this.step().context("Failed on step call")? != StepResult::Row {
355 return Ok(None);
356 }
357
358 let result = callback(this)
359 .map(|r| Some(r))
360 .context("Failed to parse row result")?;
361
362 if this.step().context("Second step call")? != StepResult::Done {
363 return Err(anyhow!(
364 "maybe called with a query that returns more than one row."
365 ));
366 }
367
368 Ok(result)
369 }
370 let result = logic(self, callback);
371 self.reset();
372 result
373 }
374
375 pub fn maybe_row<R: Column>(&mut self) -> Result<Option<R>> {
376 self.maybe(|this| this.column::<R>())
377 }
378}
379
380impl Drop for Statement<'_> {
381 fn drop(&mut self) {
382 unsafe {
383 for raw_statement in self.raw_statements.iter() {
384 sqlite3_finalize(*raw_statement);
385 }
386 }
387 }
388}
389
#[cfg(test)]
mod test {
    use indoc::indoc;

    use crate::{
        connection::Connection,
        statement::{Statement, StepResult},
    };

    // In a multi-statement batch, each parameter index only needs to exist in at
    // least one of the statements for binding to succeed.
    #[test]
    fn binding_multiple_statements_with_parameter_gaps() {
        let connection =
            Connection::open_memory(Some("binding_multiple_statements_with_parameter_gaps"));

        connection
            .exec(indoc! {"
                CREATE TABLE test (
                    col INTEGER
                )"})
            .unwrap()()
            .unwrap();

        // The INSERT references ?3 and the SELECT only ?1; index 2 is expected to
        // bind because it falls within the INSERT's parameter count.
        let statement = Statement::prepare(
            &connection,
            indoc! {"
                INSERT INTO test(col) VALUES (?3);
                SELECT * FROM test WHERE col = ?1"},
        )
        .unwrap();

        statement
            .bind_int(1, 1)
            .expect("Could not bind parameter to first index");
        statement
            .bind_int(2, 2)
            .expect("Could not bind parameter to second index");
        statement
            .bind_int(3, 3)
            .expect("Could not bind parameter to third index");
    }

    // Writes a blob through one connection and checks that a second connection
    // opened with the same in-memory database name sees the write and the delete.
    #[test]
    fn blob_round_trips() {
        let connection1 = Connection::open_memory(Some("blob_round_trips"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                )"})
            .unwrap()()
            .unwrap();

        let blob = &[0, 1, 2, 4, 8, 16, 32, 64];

        let mut write =
            Statement::prepare(&connection1, "INSERT INTO blobs (data) VALUES (?)").unwrap();
        write.bind_blob(1, blob).unwrap();
        assert_eq!(write.step().unwrap(), StepResult::Done);

        // Read the blob back through a second connection opened with the same name;
        // this relies on both connections sharing one in-memory database.
        let connection2 = Connection::open_memory(Some("blob_round_trips"));
        let mut read = Statement::prepare(&connection2, "SELECT * FROM blobs").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Row);
        assert_eq!(read.column_blob(0).unwrap(), blob);
        assert_eq!(read.step().unwrap(), StepResult::Done);

        // Delete the added blob and verify it's deleted on the other side
        connection2.exec("DELETE FROM blobs").unwrap()().unwrap();
        let mut read = Statement::prepare(&connection1, "SELECT * FROM blobs").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Done);
    }

    // `select_row` yields `None` for an empty table and `Some(text)` once a row
    // has been inserted.
    #[test]
    pub fn maybe_returns_options() {
        let connection = Connection::open_memory(Some("maybe_returns_options"));
        connection
            .exec(indoc! {"
                CREATE TABLE texts (
                    text TEXT
                )"})
            .unwrap()()
            .unwrap();

        assert!(connection
            .select_row::<String>("SELECT text FROM texts")
            .unwrap()()
            .unwrap()
            .is_none());

        let text_to_insert = "This is a test";

        connection
            .exec_bound("INSERT INTO texts VALUES (?)")
            .unwrap()(text_to_insert)
            .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM texts").unwrap()().unwrap(),
            Some(text_to_insert.to_string())
        );
    }
}