statement.rs

  1use std::ffi::{c_int, CStr, CString};
  2use std::marker::PhantomData;
  3use std::{ptr, slice, str};
  4
  5use anyhow::{anyhow, bail, Context, Result};
  6use libsqlite3_sys::*;
  7
  8use crate::bindable::{Bind, Column};
  9use crate::connection::Connection;
 10
 11pub struct Statement<'a> {
 12    /// vector of pointers to the raw SQLite statement objects.
 13    /// it holds the actual prepared statements that will be executed.
 14    raw_statements: Vec<*mut sqlite3_stmt>,
 15    /// Index of the current statement being executed from the `raw_statements` vector.
 16    current_statement: usize,
 17    /// A reference to the database connection.
 18    /// This is used to execute the statements and check for errors.
 19    connection: &'a Connection,
 20    ///Indicates that the `Statement` struct is tied to the lifetime of the SQLite statement
 21    phantom: PhantomData<sqlite3_stmt>,
 22}
 23
 24#[derive(Clone, Copy, PartialEq, Eq, Debug)]
 25pub enum StepResult {
 26    Row,
 27    Done,
 28}
 29
 30#[derive(Clone, Copy, PartialEq, Eq, Debug)]
 31pub enum SqlType {
 32    Text,
 33    Integer,
 34    Blob,
 35    Float,
 36    Null,
 37}
 38
 39impl<'a> Statement<'a> {
 40    pub fn prepare<T: AsRef<str>>(connection: &'a Connection, query: T) -> Result<Self> {
 41        let mut statement = Self {
 42            raw_statements: Default::default(),
 43            current_statement: 0,
 44            connection,
 45            phantom: PhantomData,
 46        };
 47        unsafe {
 48            let sql = CString::new(query.as_ref()).context("Error creating cstr")?;
 49            let mut remaining_sql = sql.as_c_str();
 50            while {
 51                let remaining_sql_str = remaining_sql
 52                    .to_str()
 53                    .context("Parsing remaining sql")?
 54                    .trim();
 55                remaining_sql_str != ";" && !remaining_sql_str.is_empty()
 56            } {
 57                let mut raw_statement = ptr::null_mut::<sqlite3_stmt>();
 58                let mut remaining_sql_ptr = ptr::null();
 59                sqlite3_prepare_v2(
 60                    connection.sqlite3,
 61                    remaining_sql.as_ptr(),
 62                    -1,
 63                    &mut raw_statement,
 64                    &mut remaining_sql_ptr,
 65                );
 66
 67                connection.last_error().with_context(|| {
 68                    format!("Prepare call failed for query:\n{}", query.as_ref())
 69                })?;
 70
 71                remaining_sql = CStr::from_ptr(remaining_sql_ptr);
 72                statement.raw_statements.push(raw_statement);
 73
 74                if !connection.can_write() && sqlite3_stmt_readonly(raw_statement) == 0 {
 75                    let sql = CStr::from_ptr(sqlite3_sql(raw_statement));
 76
 77                    bail!(
 78                        "Write statement prepared with connection that is not write capable. SQL:\n{} ",
 79                        sql.to_str()?)
 80                }
 81            }
 82        }
 83
 84        Ok(statement)
 85    }
 86
 87    fn current_statement(&self) -> *mut sqlite3_stmt {
 88        *self.raw_statements.get(self.current_statement).unwrap()
 89    }
 90
 91    pub fn reset(&mut self) {
 92        unsafe {
 93            for raw_statement in self.raw_statements.iter() {
 94                sqlite3_reset(*raw_statement);
 95            }
 96        }
 97        self.current_statement = 0;
 98    }
 99
100    pub fn parameter_count(&self) -> i32 {
101        unsafe {
102            self.raw_statements
103                .iter()
104                .map(|raw_statement| sqlite3_bind_parameter_count(*raw_statement))
105                .max()
106                .unwrap_or(0)
107        }
108    }
109
110    fn bind_index_with(&self, index: i32, bind: impl Fn(&*mut sqlite3_stmt)) -> Result<()> {
111        let mut any_succeed = false;
112        unsafe {
113            for raw_statement in self.raw_statements.iter() {
114                if index <= sqlite3_bind_parameter_count(*raw_statement) {
115                    bind(raw_statement);
116                    self.connection
117                        .last_error()
118                        .with_context(|| format!("Failed to bind value at index {index}"))?;
119                    any_succeed = true;
120                } else {
121                    continue;
122                }
123            }
124        }
125        if any_succeed {
126            Ok(())
127        } else {
128            Err(anyhow!("Failed to bind parameters"))
129        }
130    }
131
132    pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
133        let index = index as c_int;
134        let blob_pointer = blob.as_ptr() as *const _;
135        let len = blob.len() as c_int;
136
137        self.bind_index_with(index, |raw_statement| unsafe {
138            sqlite3_bind_blob(*raw_statement, index, blob_pointer, len, SQLITE_TRANSIENT());
139        })
140    }
141
142    pub fn column_blob(&mut self, index: i32) -> Result<&[u8]> {
143        let index = index as c_int;
144        let pointer = unsafe { sqlite3_column_blob(self.current_statement(), index) };
145
146        self.connection
147            .last_error()
148            .with_context(|| format!("Failed to read blob at index {index}"))?;
149        if pointer.is_null() {
150            return Ok(&[]);
151        }
152        let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
153        self.connection
154            .last_error()
155            .with_context(|| format!("Failed to read length of blob at index {index}"))?;
156
157        unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
158    }
159
160    pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
161        let index = index as c_int;
162
163        self.bind_index_with(index, |raw_statement| unsafe {
164            sqlite3_bind_double(*raw_statement, index, double);
165        })
166    }
167
168    pub fn column_double(&self, index: i32) -> Result<f64> {
169        let index = index as c_int;
170        let result = unsafe { sqlite3_column_double(self.current_statement(), index) };
171        self.connection
172            .last_error()
173            .with_context(|| format!("Failed to read double at index {index}"))?;
174        Ok(result)
175    }
176
177    pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
178        let index = index as c_int;
179        self.bind_index_with(index, |raw_statement| unsafe {
180            sqlite3_bind_int(*raw_statement, index, int);
181        })
182    }
183
184    pub fn column_int(&self, index: i32) -> Result<i32> {
185        let index = index as c_int;
186        let result = unsafe { sqlite3_column_int(self.current_statement(), index) };
187        self.connection
188            .last_error()
189            .with_context(|| format!("Failed to read int at index {index}"))?;
190        Ok(result)
191    }
192
193    pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
194        let index = index as c_int;
195        self.bind_index_with(index, |raw_statement| unsafe {
196            sqlite3_bind_int64(*raw_statement, index, int);
197        })
198    }
199
200    pub fn column_int64(&self, index: i32) -> Result<i64> {
201        let index = index as c_int;
202        let result = unsafe { sqlite3_column_int64(self.current_statement(), index) };
203        self.connection
204            .last_error()
205            .with_context(|| format!("Failed to read i64 at index {index}"))?;
206        Ok(result)
207    }
208
209    pub fn bind_null(&self, index: i32) -> Result<()> {
210        let index = index as c_int;
211        self.bind_index_with(index, |raw_statement| unsafe {
212            sqlite3_bind_null(*raw_statement, index);
213        })
214    }
215
216    pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
217        let index = index as c_int;
218        let text_pointer = text.as_ptr() as *const _;
219        let len = text.len() as c_int;
220
221        self.bind_index_with(index, |raw_statement| unsafe {
222            sqlite3_bind_text(*raw_statement, index, text_pointer, len, SQLITE_TRANSIENT());
223        })
224    }
225
226    pub fn column_text(&mut self, index: i32) -> Result<&str> {
227        let index = index as c_int;
228        let pointer = unsafe { sqlite3_column_text(self.current_statement(), index) };
229
230        self.connection
231            .last_error()
232            .with_context(|| format!("Failed to read text from column {index}"))?;
233        if pointer.is_null() {
234            return Ok("");
235        }
236        let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
237        self.connection
238            .last_error()
239            .with_context(|| format!("Failed to read text length at {index}"))?;
240
241        let slice = unsafe { slice::from_raw_parts(pointer, len) };
242        Ok(str::from_utf8(slice)?)
243    }
244
245    pub fn bind<T: Bind>(&self, value: &T, index: i32) -> Result<i32> {
246        debug_assert!(index > 0);
247        value.bind(self, index)
248    }
249
250    pub fn column<T: Column>(&mut self) -> Result<T> {
251        Ok(T::column(self, 0)?.0)
252    }
253
254    pub fn column_type(&mut self, index: i32) -> Result<SqlType> {
255        let result = unsafe { sqlite3_column_type(self.current_statement(), index) };
256        self.connection.last_error()?;
257        match result {
258            SQLITE_INTEGER => Ok(SqlType::Integer),
259            SQLITE_FLOAT => Ok(SqlType::Float),
260            SQLITE_TEXT => Ok(SqlType::Text),
261            SQLITE_BLOB => Ok(SqlType::Blob),
262            SQLITE_NULL => Ok(SqlType::Null),
263            _ => Err(anyhow!("Column type returned was incorrect ")),
264        }
265    }
266
267    pub fn with_bindings(&mut self, bindings: &impl Bind) -> Result<&mut Self> {
268        self.bind(bindings, 1)?;
269        Ok(self)
270    }
271
272    fn step(&mut self) -> Result<StepResult> {
273        unsafe {
274            match sqlite3_step(self.current_statement()) {
275                SQLITE_ROW => Ok(StepResult::Row),
276                SQLITE_DONE => {
277                    if self.current_statement >= self.raw_statements.len() - 1 {
278                        Ok(StepResult::Done)
279                    } else {
280                        self.current_statement += 1;
281                        self.step()
282                    }
283                }
284                SQLITE_MISUSE => Err(anyhow!("Statement step returned SQLITE_MISUSE")),
285                _other_error => {
286                    self.connection.last_error()?;
287                    unreachable!("Step returned error code and last error failed to catch it");
288                }
289            }
290        }
291    }
292
293    pub fn exec(&mut self) -> Result<()> {
294        fn logic(this: &mut Statement) -> Result<()> {
295            while this.step()? == StepResult::Row {}
296            Ok(())
297        }
298        let result = logic(self);
299        self.reset();
300        result
301    }
302
303    pub fn map<R>(&mut self, callback: impl FnMut(&mut Statement) -> Result<R>) -> Result<Vec<R>> {
304        fn logic<R>(
305            this: &mut Statement,
306            mut callback: impl FnMut(&mut Statement) -> Result<R>,
307        ) -> Result<Vec<R>> {
308            let mut mapped_rows = Vec::new();
309            while this.step()? == StepResult::Row {
310                mapped_rows.push(callback(this)?);
311            }
312            Ok(mapped_rows)
313        }
314
315        let result = logic(self, callback);
316        self.reset();
317        result
318    }
319
320    pub fn rows<R: Column>(&mut self) -> Result<Vec<R>> {
321        self.map(|s| s.column::<R>())
322    }
323
324    pub fn single<R>(&mut self, callback: impl FnOnce(&mut Statement) -> Result<R>) -> Result<R> {
325        fn logic<R>(
326            this: &mut Statement,
327            callback: impl FnOnce(&mut Statement) -> Result<R>,
328        ) -> Result<R> {
329            println!("{:?}", std::any::type_name::<R>());
330            if this.step()? != StepResult::Row {
331                return Err(anyhow!("single called with query that returns no rows."));
332            }
333            let result = callback(this)?;
334
335            if this.step()? != StepResult::Done {
336                return Err(anyhow!(
337                    "single called with a query that returns more than one row."
338                ));
339            }
340
341            Ok(result)
342        }
343        let result = logic(self, callback);
344        self.reset();
345        result
346    }
347
348    pub fn row<R: Column>(&mut self) -> Result<R> {
349        self.single(|this| this.column::<R>())
350    }
351
352    pub fn maybe<R>(
353        &mut self,
354        callback: impl FnOnce(&mut Statement) -> Result<R>,
355    ) -> Result<Option<R>> {
356        fn logic<R>(
357            this: &mut Statement,
358            callback: impl FnOnce(&mut Statement) -> Result<R>,
359        ) -> Result<Option<R>> {
360            if this.step().context("Failed on step call")? != StepResult::Row {
361                return Ok(None);
362            }
363
364            let result = callback(this)
365                .map(|r| Some(r))
366                .context("Failed to parse row result")?;
367
368            if this.step().context("Second step call")? != StepResult::Done {
369                return Err(anyhow!(
370                    "maybe called with a query that returns more than one row."
371                ));
372            }
373
374            Ok(result)
375        }
376        let result = logic(self, callback);
377        self.reset();
378        result
379    }
380
381    pub fn maybe_row<R: Column>(&mut self) -> Result<Option<R>> {
382        self.maybe(|this| this.column::<R>())
383    }
384}
385
386impl Drop for Statement<'_> {
387    fn drop(&mut self) {
388        unsafe {
389            for raw_statement in self.raw_statements.iter() {
390                sqlite3_finalize(*raw_statement);
391            }
392        }
393    }
394}
395
#[cfg(test)]
mod test {
    use indoc::indoc;

    use crate::{
        connection::Connection,
        statement::{Statement, StepResult},
    };

    /// A parameter index binds in every prepared statement that has at least that many
    /// parameters, so indices 1..=3 must all bind even though `?2` never appears literally.
    #[test]
    fn binding_multiple_statements_with_parameter_gaps() {
        let connection =
            Connection::open_memory(Some("binding_multiple_statements_with_parameter_gaps"));

        connection
            .exec(indoc! {"
            CREATE TABLE test (
                col INTEGER
            )"})
            .unwrap()()
        .unwrap();

        let statement = Statement::prepare(
            &connection,
            indoc! {"
                INSERT INTO test(col) VALUES (?3);
                SELECT * FROM test WHERE col = ?1"},
        )
        .unwrap();

        statement
            .bind_int(1, 1)
            .expect("Could not bind parameter to first index");
        statement
            .bind_int(2, 2)
            .expect("Could not bind parameter to second index");
        statement
            .bind_int(3, 3)
            .expect("Could not bind parameter to third index");
    }

    /// A blob written through one connection reads back unchanged through another connection
    /// opened with the same in-memory database name, and deletions are visible both ways.
    #[test]
    fn blob_round_trips() {
        let connection1 = Connection::open_memory(Some("blob_round_trips"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                )"})
            .unwrap()()
        .unwrap();

        let blob = &[0, 1, 2, 4, 8, 16, 32, 64];

        let mut write =
            Statement::prepare(&connection1, "INSERT INTO blobs (data) VALUES (?)").unwrap();
        write.bind_blob(1, blob).unwrap();
        assert_eq!(write.step().unwrap(), StepResult::Done);

        // Read the blob back through a second connection opened with the same database name
        let connection2 = Connection::open_memory(Some("blob_round_trips"));
        let mut read = Statement::prepare(&connection2, "SELECT * FROM blobs").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Row);
        assert_eq!(read.column_blob(0).unwrap(), blob);
        assert_eq!(read.step().unwrap(), StepResult::Done);

        // Delete the added blob and verify it's deleted on the other side
        connection2.exec("DELETE FROM blobs").unwrap()().unwrap();
        let mut read = Statement::prepare(&connection1, "SELECT * FROM blobs").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Done);
    }

    /// `select_row` yields `None` on an empty table and `Some(value)` once a row exists.
    #[test]
    fn maybe_returns_options() {
        let connection = Connection::open_memory(Some("maybe_returns_options"));
        connection
            .exec(indoc! {"
                CREATE TABLE texts (
                    text TEXT
                )"})
            .unwrap()()
        .unwrap();

        assert!(connection
            .select_row::<String>("SELECT text FROM texts")
            .unwrap()()
        .unwrap()
        .is_none());

        let text_to_insert = "This is a test";

        connection
            .exec_bound("INSERT INTO texts VALUES (?)")
            .unwrap()(text_to_insert)
        .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM texts").unwrap()().unwrap(),
            Some(text_to_insert.to_string())
        );
    }
}