statement.rs

  1use std::ffi::{CStr, CString, c_int};
  2use std::marker::PhantomData;
  3use std::{ptr, slice, str};
  4
  5use anyhow::{Context, Result, anyhow, bail};
  6use libsqlite3_sys::*;
  7
  8use crate::bindable::{Bind, Column};
  9use crate::connection::Connection;
 10
 11pub struct Statement<'a> {
 12    /// vector of pointers to the raw SQLite statement objects.
 13    /// it holds the actual prepared statements that will be executed.
 14    pub raw_statements: Vec<*mut sqlite3_stmt>,
 15    /// Index of the current statement being executed from the `raw_statements` vector.
 16    current_statement: usize,
 17    /// A reference to the database connection.
 18    /// This is used to execute the statements and check for errors.
 19    connection: &'a Connection,
 20    ///Indicates that the `Statement` struct is tied to the lifetime of the SQLite statement
 21    phantom: PhantomData<sqlite3_stmt>,
 22}
 23
 24#[derive(Clone, Copy, PartialEq, Eq, Debug)]
 25pub enum StepResult {
 26    Row,
 27    Done,
 28}
 29
 30#[derive(Clone, Copy, PartialEq, Eq, Debug)]
 31pub enum SqlType {
 32    Text,
 33    Integer,
 34    Blob,
 35    Float,
 36    Null,
 37}
 38
 39impl<'a> Statement<'a> {
 40    pub fn prepare<T: AsRef<str>>(connection: &'a Connection, query: T) -> Result<Self> {
 41        let mut statement = Self {
 42            raw_statements: Default::default(),
 43            current_statement: 0,
 44            connection,
 45            phantom: PhantomData,
 46        };
 47        unsafe {
 48            let sql = CString::new(query.as_ref()).context("Error creating cstr")?;
 49            let mut remaining_sql = sql.as_c_str();
 50            while {
 51                let remaining_sql_str = remaining_sql
 52                    .to_str()
 53                    .context("Parsing remaining sql")?
 54                    .trim();
 55                remaining_sql_str != ";" && !remaining_sql_str.is_empty()
 56            } {
 57                let mut raw_statement = ptr::null_mut::<sqlite3_stmt>();
 58                let mut remaining_sql_ptr = ptr::null();
 59                sqlite3_prepare_v2(
 60                    connection.sqlite3,
 61                    remaining_sql.as_ptr(),
 62                    -1,
 63                    &mut raw_statement,
 64                    &mut remaining_sql_ptr,
 65                );
 66
 67                connection.last_error().with_context(|| {
 68                    format!("Prepare call failed for query:\n{}", query.as_ref())
 69                })?;
 70
 71                remaining_sql = CStr::from_ptr(remaining_sql_ptr);
 72                statement.raw_statements.push(raw_statement);
 73
 74                if !connection.can_write() && sqlite3_stmt_readonly(raw_statement) == 0 {
 75                    let sql = CStr::from_ptr(sqlite3_sql(raw_statement));
 76
 77                    bail!(
 78                        "Write statement prepared with connection that is not write capable. SQL:\n{} ",
 79                        sql.to_str()?
 80                    )
 81                }
 82            }
 83        }
 84
 85        Ok(statement)
 86    }
 87
 88    fn current_statement(&self) -> *mut sqlite3_stmt {
 89        *self.raw_statements.get(self.current_statement).unwrap()
 90    }
 91
 92    pub fn reset(&mut self) {
 93        unsafe {
 94            for raw_statement in self.raw_statements.iter() {
 95                sqlite3_reset(*raw_statement);
 96            }
 97        }
 98        self.current_statement = 0;
 99    }
100
101    pub fn parameter_count(&self) -> i32 {
102        unsafe {
103            self.raw_statements
104                .iter()
105                .map(|raw_statement| sqlite3_bind_parameter_count(*raw_statement))
106                .max()
107                .unwrap_or(0)
108        }
109    }
110
111    fn bind_index_with(&self, index: i32, bind: impl Fn(&*mut sqlite3_stmt)) -> Result<()> {
112        let mut any_succeed = false;
113        unsafe {
114            for raw_statement in self.raw_statements.iter() {
115                if index <= sqlite3_bind_parameter_count(*raw_statement) {
116                    bind(raw_statement);
117                    self.connection
118                        .last_error()
119                        .with_context(|| format!("Failed to bind value at index {index}"))?;
120                    any_succeed = true;
121                } else {
122                    continue;
123                }
124            }
125        }
126        if any_succeed {
127            Ok(())
128        } else {
129            Err(anyhow!("Failed to bind parameters"))
130        }
131    }
132
133    pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
134        let index = index as c_int;
135        let blob_pointer = blob.as_ptr() as *const _;
136        let len = blob.len() as c_int;
137
138        self.bind_index_with(index, |raw_statement| unsafe {
139            sqlite3_bind_blob(*raw_statement, index, blob_pointer, len, SQLITE_TRANSIENT());
140        })
141    }
142
143    pub fn column_blob(&mut self, index: i32) -> Result<&[u8]> {
144        let index = index as c_int;
145        let pointer = unsafe { sqlite3_column_blob(self.current_statement(), index) };
146
147        self.connection
148            .last_error()
149            .with_context(|| format!("Failed to read blob at index {index}"))?;
150        if pointer.is_null() {
151            return Ok(&[]);
152        }
153        let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
154        self.connection
155            .last_error()
156            .with_context(|| format!("Failed to read length of blob at index {index}"))?;
157
158        unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
159    }
160
161    pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
162        let index = index as c_int;
163
164        self.bind_index_with(index, |raw_statement| unsafe {
165            sqlite3_bind_double(*raw_statement, index, double);
166        })
167    }
168
169    pub fn column_double(&self, index: i32) -> Result<f64> {
170        let index = index as c_int;
171        let result = unsafe { sqlite3_column_double(self.current_statement(), index) };
172        self.connection
173            .last_error()
174            .with_context(|| format!("Failed to read double at index {index}"))?;
175        Ok(result)
176    }
177
178    pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
179        let index = index as c_int;
180        self.bind_index_with(index, |raw_statement| unsafe {
181            sqlite3_bind_int(*raw_statement, index, int);
182        })
183    }
184
185    pub fn column_int(&self, index: i32) -> Result<i32> {
186        let index = index as c_int;
187        let result = unsafe { sqlite3_column_int(self.current_statement(), index) };
188        self.connection
189            .last_error()
190            .with_context(|| format!("Failed to read int at index {index}"))?;
191        Ok(result)
192    }
193
194    pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
195        let index = index as c_int;
196        self.bind_index_with(index, |raw_statement| unsafe {
197            sqlite3_bind_int64(*raw_statement, index, int);
198        })
199    }
200
201    pub fn column_int64(&self, index: i32) -> Result<i64> {
202        let index = index as c_int;
203        let result = unsafe { sqlite3_column_int64(self.current_statement(), index) };
204        self.connection
205            .last_error()
206            .with_context(|| format!("Failed to read i64 at index {index}"))?;
207        Ok(result)
208    }
209
210    pub fn bind_null(&self, index: i32) -> Result<()> {
211        let index = index as c_int;
212        self.bind_index_with(index, |raw_statement| unsafe {
213            sqlite3_bind_null(*raw_statement, index);
214        })
215    }
216
217    pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
218        let index = index as c_int;
219        let text_pointer = text.as_ptr() as *const _;
220        let len = text.len() as c_int;
221
222        self.bind_index_with(index, |raw_statement| unsafe {
223            sqlite3_bind_text(*raw_statement, index, text_pointer, len, SQLITE_TRANSIENT());
224        })
225    }
226
227    pub fn column_text(&mut self, index: i32) -> Result<&str> {
228        let index = index as c_int;
229        let pointer = unsafe { sqlite3_column_text(self.current_statement(), index) };
230
231        self.connection
232            .last_error()
233            .with_context(|| format!("Failed to read text from column {index}"))?;
234        if pointer.is_null() {
235            return Ok("");
236        }
237        let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
238        self.connection
239            .last_error()
240            .with_context(|| format!("Failed to read text length at {index}"))?;
241
242        let slice = unsafe { slice::from_raw_parts(pointer, len) };
243        Ok(str::from_utf8(slice)?)
244    }
245
246    pub fn bind<T: Bind>(&self, value: &T, index: i32) -> Result<i32> {
247        debug_assert!(index > 0);
248        value.bind(self, index)
249    }
250
251    pub fn column<T: Column>(&mut self) -> Result<T> {
252        Ok(T::column(self, 0)?.0)
253    }
254
255    pub fn column_type(&mut self, index: i32) -> Result<SqlType> {
256        let result = unsafe { sqlite3_column_type(self.current_statement(), index) };
257        self.connection.last_error()?;
258        match result {
259            SQLITE_INTEGER => Ok(SqlType::Integer),
260            SQLITE_FLOAT => Ok(SqlType::Float),
261            SQLITE_TEXT => Ok(SqlType::Text),
262            SQLITE_BLOB => Ok(SqlType::Blob),
263            SQLITE_NULL => Ok(SqlType::Null),
264            _ => Err(anyhow!("Column type returned was incorrect ")),
265        }
266    }
267
268    pub fn with_bindings(&mut self, bindings: &impl Bind) -> Result<&mut Self> {
269        self.bind(bindings, 1)?;
270        Ok(self)
271    }
272
273    fn step(&mut self) -> Result<StepResult> {
274        unsafe {
275            match sqlite3_step(self.current_statement()) {
276                SQLITE_ROW => Ok(StepResult::Row),
277                SQLITE_DONE => {
278                    if self.current_statement >= self.raw_statements.len() - 1 {
279                        Ok(StepResult::Done)
280                    } else {
281                        self.current_statement += 1;
282                        self.step()
283                    }
284                }
285                SQLITE_MISUSE => Err(anyhow!("Statement step returned SQLITE_MISUSE")),
286                _other_error => {
287                    self.connection.last_error()?;
288                    unreachable!("Step returned error code and last error failed to catch it");
289                }
290            }
291        }
292    }
293
294    pub fn exec(&mut self) -> Result<()> {
295        fn logic(this: &mut Statement) -> Result<()> {
296            while this.step()? == StepResult::Row {}
297            Ok(())
298        }
299        let result = logic(self);
300        self.reset();
301        result
302    }
303
304    pub fn map<R>(&mut self, callback: impl FnMut(&mut Statement) -> Result<R>) -> Result<Vec<R>> {
305        fn logic<R>(
306            this: &mut Statement,
307            mut callback: impl FnMut(&mut Statement) -> Result<R>,
308        ) -> Result<Vec<R>> {
309            let mut mapped_rows = Vec::new();
310            while this.step()? == StepResult::Row {
311                mapped_rows.push(callback(this)?);
312            }
313            Ok(mapped_rows)
314        }
315
316        let result = logic(self, callback);
317        self.reset();
318        result
319    }
320
321    pub fn rows<R: Column>(&mut self) -> Result<Vec<R>> {
322        self.map(|s| s.column::<R>())
323    }
324
325    pub fn single<R>(&mut self, callback: impl FnOnce(&mut Statement) -> Result<R>) -> Result<R> {
326        fn logic<R>(
327            this: &mut Statement,
328            callback: impl FnOnce(&mut Statement) -> Result<R>,
329        ) -> Result<R> {
330            println!("{:?}", std::any::type_name::<R>());
331            if this.step()? != StepResult::Row {
332                return Err(anyhow!("single called with query that returns no rows."));
333            }
334            let result = callback(this)?;
335
336            if this.step()? != StepResult::Done {
337                return Err(anyhow!(
338                    "single called with a query that returns more than one row."
339                ));
340            }
341
342            Ok(result)
343        }
344        let result = logic(self, callback);
345        self.reset();
346        result
347    }
348
349    pub fn row<R: Column>(&mut self) -> Result<R> {
350        self.single(|this| this.column::<R>())
351    }
352
353    pub fn maybe<R>(
354        &mut self,
355        callback: impl FnOnce(&mut Statement) -> Result<R>,
356    ) -> Result<Option<R>> {
357        fn logic<R>(
358            this: &mut Statement,
359            callback: impl FnOnce(&mut Statement) -> Result<R>,
360        ) -> Result<Option<R>> {
361            if this.step().context("Failed on step call")? != StepResult::Row {
362                return Ok(None);
363            }
364
365            let result = callback(this)
366                .map(|r| Some(r))
367                .context("Failed to parse row result")?;
368
369            if this.step().context("Second step call")? != StepResult::Done {
370                return Err(anyhow!(
371                    "maybe called with a query that returns more than one row."
372                ));
373            }
374
375            Ok(result)
376        }
377        let result = logic(self, callback);
378        self.reset();
379        result
380    }
381
382    pub fn maybe_row<R: Column>(&mut self) -> Result<Option<R>> {
383        self.maybe(|this| this.column::<R>())
384    }
385}
386
387impl Drop for Statement<'_> {
388    fn drop(&mut self) {
389        unsafe {
390            for raw_statement in self.raw_statements.iter() {
391                sqlite3_finalize(*raw_statement);
392            }
393        }
394    }
395}
396
#[cfg(test)]
mod test {
    use indoc::indoc;

    use crate::{
        connection::Connection,
        statement::{Statement, StepResult},
    };

    /// Binding by index succeeds whenever at least one statement in a multi-statement query
    /// has a parameter at that index, even if the other statements do not.
    #[test]
    fn binding_multiple_statements_with_parameter_gaps() {
        let connection =
            Connection::open_memory(Some("binding_multiple_statements_with_parameter_gaps"));

        connection
            .exec(indoc! {"
            CREATE TABLE test (
                col INTEGER
            )"})
            .unwrap()()
        .unwrap();

        let statement = Statement::prepare(
            &connection,
            indoc! {"
                INSERT INTO test(col) VALUES (?3);
                SELECT * FROM test WHERE col = ?1"},
        )
        .unwrap();

        // Index 1 is in range for both statements. Indices 2 and 3 are only in range for the
        // first statement, whose `?3` gives it a parameter count of 3.
        statement
            .bind_int(1, 1)
            .expect("Could not bind parameter to first index");
        statement
            .bind_int(2, 2)
            .expect("Could not bind parameter to second index");
        statement
            .bind_int(3, 3)
            .expect("Could not bind parameter to third index");
    }

    /// A blob written through one connection reads back byte-for-byte through another.
    #[test]
    fn blob_round_trips() {
        let connection1 = Connection::open_memory(Some("blob_round_trips"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                )"})
            .unwrap()()
        .unwrap();

        let blob = &[0, 1, 2, 4, 8, 16, 32, 64];

        let mut write =
            Statement::prepare(&connection1, "INSERT INTO blobs (data) VALUES (?)").unwrap();
        write.bind_blob(1, blob).unwrap();
        assert_eq!(write.step().unwrap(), StepResult::Done);

        // Read the blob back through a second connection opened with the same name; the
        // assertions below expect it to see the row written through `connection1`.
        let connection2 = Connection::open_memory(Some("blob_round_trips"));
        let mut read = Statement::prepare(&connection2, "SELECT * FROM blobs").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Row);
        assert_eq!(read.column_blob(0).unwrap(), blob);
        assert_eq!(read.step().unwrap(), StepResult::Done);

        // Delete the blob through the second connection and verify it's gone on the first.
        connection2.exec("DELETE FROM blobs").unwrap()().unwrap();
        let mut read = Statement::prepare(&connection1, "SELECT * FROM blobs").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Done);
    }

    /// `select_row` yields `None` for an empty table and `Some` once a row exists.
    #[test]
    fn maybe_returns_options() {
        let connection = Connection::open_memory(Some("maybe_returns_options"));
        connection
            .exec(indoc! {"
                CREATE TABLE texts (
                    text TEXT
                )"})
            .unwrap()()
        .unwrap();

        assert!(
            connection
                .select_row::<String>("SELECT text FROM texts")
                .unwrap()()
            .unwrap()
            .is_none()
        );

        let text_to_insert = "This is a test";

        connection
            .exec_bound("INSERT INTO texts VALUES (?)")
            .unwrap()(text_to_insert)
        .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM texts").unwrap()().unwrap(),
            Some(text_to_insert.to_string())
        );
    }
}