statement.rs

  1use std::ffi::{c_int, CStr, CString};
  2use std::marker::PhantomData;
  3use std::{ptr, slice, str};
  4
  5use anyhow::{anyhow, bail, Context, Result};
  6use libsqlite3_sys::*;
  7
  8use crate::bindable::{Bind, Column};
  9use crate::connection::Connection;
 10
 11pub struct Statement<'a> {
 12    raw_statements: Vec<*mut sqlite3_stmt>,
 13    current_statement: usize,
 14    connection: &'a Connection,
 15    phantom: PhantomData<sqlite3_stmt>,
 16}
 17
 18#[derive(Clone, Copy, PartialEq, Eq, Debug)]
 19pub enum StepResult {
 20    Row,
 21    Done,
 22}
 23
 24#[derive(Clone, Copy, PartialEq, Eq, Debug)]
 25pub enum SqlType {
 26    Text,
 27    Integer,
 28    Blob,
 29    Float,
 30    Null,
 31}
 32
 33impl<'a> Statement<'a> {
 34    pub fn prepare<T: AsRef<str>>(connection: &'a Connection, query: T) -> Result<Self> {
 35        let mut statement = Self {
 36            raw_statements: Default::default(),
 37            current_statement: 0,
 38            connection,
 39            phantom: PhantomData,
 40        };
 41        unsafe {
 42            let sql = CString::new(query.as_ref()).context("Error creating cstr")?;
 43            let mut remaining_sql = sql.as_c_str();
 44            while {
 45                let remaining_sql_str = remaining_sql
 46                    .to_str()
 47                    .context("Parsing remaining sql")?
 48                    .trim();
 49                remaining_sql_str != ";" && !remaining_sql_str.is_empty()
 50            } {
 51                let mut raw_statement = ptr::null_mut::<sqlite3_stmt>();
 52                let mut remaining_sql_ptr = ptr::null();
 53                sqlite3_prepare_v2(
 54                    connection.sqlite3,
 55                    remaining_sql.as_ptr(),
 56                    -1,
 57                    &mut raw_statement,
 58                    &mut remaining_sql_ptr,
 59                );
 60
 61                connection.last_error().with_context(|| {
 62                    format!("Prepare call failed for query:\n{}", query.as_ref())
 63                })?;
 64
 65                remaining_sql = CStr::from_ptr(remaining_sql_ptr);
 66                statement.raw_statements.push(raw_statement);
 67
 68                if !connection.can_write() && sqlite3_stmt_readonly(raw_statement) == 0 {
 69                    let sql = CStr::from_ptr(sqlite3_sql(raw_statement));
 70
 71                    bail!(
 72                        "Write statement prepared with connection that is not write capable. SQL:\n{} ",
 73                        sql.to_str()?)
 74                }
 75            }
 76        }
 77
 78        Ok(statement)
 79    }
 80
 81    fn current_statement(&self) -> *mut sqlite3_stmt {
 82        *self.raw_statements.get(self.current_statement).unwrap()
 83    }
 84
 85    pub fn reset(&mut self) {
 86        unsafe {
 87            for raw_statement in self.raw_statements.iter() {
 88                sqlite3_reset(*raw_statement);
 89            }
 90        }
 91        self.current_statement = 0;
 92    }
 93
 94    pub fn parameter_count(&self) -> i32 {
 95        unsafe {
 96            self.raw_statements
 97                .iter()
 98                .map(|raw_statement| sqlite3_bind_parameter_count(*raw_statement))
 99                .max()
100                .unwrap_or(0)
101        }
102    }
103
104    fn bind_index_with(&self, index: i32, bind: impl Fn(&*mut sqlite3_stmt)) -> Result<()> {
105        let mut any_succeed = false;
106        unsafe {
107            for raw_statement in self.raw_statements.iter() {
108                if index <= sqlite3_bind_parameter_count(*raw_statement) {
109                    bind(raw_statement);
110                    self.connection
111                        .last_error()
112                        .with_context(|| format!("Failed to bind value at index {index}"))?;
113                    any_succeed = true;
114                } else {
115                    continue;
116                }
117            }
118        }
119        if any_succeed {
120            Ok(())
121        } else {
122            Err(anyhow!("Failed to bind parameters"))
123        }
124    }
125
126    pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
127        let index = index as c_int;
128        let blob_pointer = blob.as_ptr() as *const _;
129        let len = blob.len() as c_int;
130
131        self.bind_index_with(index, |raw_statement| unsafe {
132            sqlite3_bind_blob(*raw_statement, index, blob_pointer, len, SQLITE_TRANSIENT());
133        })
134    }
135
136    pub fn column_blob(&mut self, index: i32) -> Result<&[u8]> {
137        let index = index as c_int;
138        let pointer = unsafe { sqlite3_column_blob(self.current_statement(), index) };
139
140        self.connection
141            .last_error()
142            .with_context(|| format!("Failed to read blob at index {index}"))?;
143        if pointer.is_null() {
144            return Ok(&[]);
145        }
146        let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
147        self.connection
148            .last_error()
149            .with_context(|| format!("Failed to read length of blob at index {index}"))?;
150
151        unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
152    }
153
154    pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
155        let index = index as c_int;
156
157        self.bind_index_with(index, |raw_statement| unsafe {
158            sqlite3_bind_double(*raw_statement, index, double);
159        })
160    }
161
162    pub fn column_double(&self, index: i32) -> Result<f64> {
163        let index = index as c_int;
164        let result = unsafe { sqlite3_column_double(self.current_statement(), index) };
165        self.connection
166            .last_error()
167            .with_context(|| format!("Failed to read double at index {index}"))?;
168        Ok(result)
169    }
170
171    pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
172        let index = index as c_int;
173        self.bind_index_with(index, |raw_statement| unsafe {
174            sqlite3_bind_int(*raw_statement, index, int);
175        })
176    }
177
178    pub fn column_int(&self, index: i32) -> Result<i32> {
179        let index = index as c_int;
180        let result = unsafe { sqlite3_column_int(self.current_statement(), index) };
181        self.connection
182            .last_error()
183            .with_context(|| format!("Failed to read int at index {index}"))?;
184        Ok(result)
185    }
186
187    pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
188        let index = index as c_int;
189        self.bind_index_with(index, |raw_statement| unsafe {
190            sqlite3_bind_int64(*raw_statement, index, int);
191        })
192    }
193
194    pub fn column_int64(&self, index: i32) -> Result<i64> {
195        let index = index as c_int;
196        let result = unsafe { sqlite3_column_int64(self.current_statement(), index) };
197        self.connection
198            .last_error()
199            .with_context(|| format!("Failed to read i64 at index {index}"))?;
200        Ok(result)
201    }
202
203    pub fn bind_null(&self, index: i32) -> Result<()> {
204        let index = index as c_int;
205        self.bind_index_with(index, |raw_statement| unsafe {
206            sqlite3_bind_null(*raw_statement, index);
207        })
208    }
209
210    pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
211        let index = index as c_int;
212        let text_pointer = text.as_ptr() as *const _;
213        let len = text.len() as c_int;
214
215        self.bind_index_with(index, |raw_statement| unsafe {
216            sqlite3_bind_text(*raw_statement, index, text_pointer, len, SQLITE_TRANSIENT());
217        })
218    }
219
220    pub fn column_text(&mut self, index: i32) -> Result<&str> {
221        let index = index as c_int;
222        let pointer = unsafe { sqlite3_column_text(self.current_statement(), index) };
223
224        self.connection
225            .last_error()
226            .with_context(|| format!("Failed to read text from column {index}"))?;
227        if pointer.is_null() {
228            return Ok("");
229        }
230        let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
231        self.connection
232            .last_error()
233            .with_context(|| format!("Failed to read text length at {index}"))?;
234
235        let slice = unsafe { slice::from_raw_parts(pointer, len) };
236        Ok(str::from_utf8(slice)?)
237    }
238
239    pub fn bind<T: Bind>(&self, value: &T, index: i32) -> Result<i32> {
240        debug_assert!(index > 0);
241        value.bind(self, index)
242    }
243
244    pub fn column<T: Column>(&mut self) -> Result<T> {
245        Ok(T::column(self, 0)?.0)
246    }
247
248    pub fn column_type(&mut self, index: i32) -> Result<SqlType> {
249        let result = unsafe { sqlite3_column_type(self.current_statement(), index) };
250        self.connection.last_error()?;
251        match result {
252            SQLITE_INTEGER => Ok(SqlType::Integer),
253            SQLITE_FLOAT => Ok(SqlType::Float),
254            SQLITE_TEXT => Ok(SqlType::Text),
255            SQLITE_BLOB => Ok(SqlType::Blob),
256            SQLITE_NULL => Ok(SqlType::Null),
257            _ => Err(anyhow!("Column type returned was incorrect ")),
258        }
259    }
260
261    pub fn with_bindings(&mut self, bindings: &impl Bind) -> Result<&mut Self> {
262        self.bind(bindings, 1)?;
263        Ok(self)
264    }
265
266    fn step(&mut self) -> Result<StepResult> {
267        unsafe {
268            match sqlite3_step(self.current_statement()) {
269                SQLITE_ROW => Ok(StepResult::Row),
270                SQLITE_DONE => {
271                    if self.current_statement >= self.raw_statements.len() - 1 {
272                        Ok(StepResult::Done)
273                    } else {
274                        self.current_statement += 1;
275                        self.step()
276                    }
277                }
278                SQLITE_MISUSE => Err(anyhow!("Statement step returned SQLITE_MISUSE")),
279                _other_error => {
280                    self.connection.last_error()?;
281                    unreachable!("Step returned error code and last error failed to catch it");
282                }
283            }
284        }
285    }
286
287    pub fn exec(&mut self) -> Result<()> {
288        fn logic(this: &mut Statement) -> Result<()> {
289            while this.step()? == StepResult::Row {}
290            Ok(())
291        }
292        let result = logic(self);
293        self.reset();
294        result
295    }
296
297    pub fn map<R>(&mut self, callback: impl FnMut(&mut Statement) -> Result<R>) -> Result<Vec<R>> {
298        fn logic<R>(
299            this: &mut Statement,
300            mut callback: impl FnMut(&mut Statement) -> Result<R>,
301        ) -> Result<Vec<R>> {
302            let mut mapped_rows = Vec::new();
303            while this.step()? == StepResult::Row {
304                mapped_rows.push(callback(this)?);
305            }
306            Ok(mapped_rows)
307        }
308
309        let result = logic(self, callback);
310        self.reset();
311        result
312    }
313
314    pub fn rows<R: Column>(&mut self) -> Result<Vec<R>> {
315        self.map(|s| s.column::<R>())
316    }
317
318    pub fn single<R>(&mut self, callback: impl FnOnce(&mut Statement) -> Result<R>) -> Result<R> {
319        fn logic<R>(
320            this: &mut Statement,
321            callback: impl FnOnce(&mut Statement) -> Result<R>,
322        ) -> Result<R> {
323            println!("{:?}", std::any::type_name::<R>());
324            if this.step()? != StepResult::Row {
325                return Err(anyhow!("single called with query that returns no rows."));
326            }
327            let result = callback(this)?;
328
329            if this.step()? != StepResult::Done {
330                return Err(anyhow!(
331                    "single called with a query that returns more than one row."
332                ));
333            }
334
335            Ok(result)
336        }
337        let result = logic(self, callback);
338        self.reset();
339        result
340    }
341
342    pub fn row<R: Column>(&mut self) -> Result<R> {
343        self.single(|this| this.column::<R>())
344    }
345
346    pub fn maybe<R>(
347        &mut self,
348        callback: impl FnOnce(&mut Statement) -> Result<R>,
349    ) -> Result<Option<R>> {
350        fn logic<R>(
351            this: &mut Statement,
352            callback: impl FnOnce(&mut Statement) -> Result<R>,
353        ) -> Result<Option<R>> {
354            if this.step().context("Failed on step call")? != StepResult::Row {
355                return Ok(None);
356            }
357
358            let result = callback(this)
359                .map(|r| Some(r))
360                .context("Failed to parse row result")?;
361
362            if this.step().context("Second step call")? != StepResult::Done {
363                return Err(anyhow!(
364                    "maybe called with a query that returns more than one row."
365                ));
366            }
367
368            Ok(result)
369        }
370        let result = logic(self, callback);
371        self.reset();
372        result
373    }
374
375    pub fn maybe_row<R: Column>(&mut self) -> Result<Option<R>> {
376        self.maybe(|this| this.column::<R>())
377    }
378}
379
380impl Drop for Statement<'_> {
381    fn drop(&mut self) {
382        unsafe {
383            for raw_statement in self.raw_statements.iter() {
384                sqlite3_finalize(*raw_statement);
385            }
386        }
387    }
388}
389
#[cfg(test)]
mod test {
    use indoc::indoc;

    use crate::{
        connection::Connection,
        statement::{Statement, StepResult},
    };

    /// Binding an index succeeds as long as at least one statement in a
    /// multi-statement query has a parameter slot there, even if the index is a
    /// gap (e.g. `?2` when only `?1` and `?3` are used).
    #[test]
    fn binding_multiple_statements_with_parameter_gaps() {
        let connection =
            Connection::open_memory(Some("binding_multiple_statements_with_parameter_gaps"));

        connection
            .exec(indoc! {"
            CREATE TABLE test (
                col INTEGER
            )"})
            .unwrap()()
        .unwrap();

        let statement = Statement::prepare(
            &connection,
            indoc! {"
                INSERT INTO test(col) VALUES (?3);
                SELECT * FROM test WHERE col = ?1"},
        )
        .unwrap();

        statement
            .bind_int(1, 1)
            .expect("Could not bind parameter to first index");
        statement
            .bind_int(2, 2)
            .expect("Could not bind parameter to second index");
        statement
            .bind_int(3, 3)
            .expect("Could not bind parameter to third index");
    }

    /// A blob written through one connection reads back byte-for-byte through a
    /// second connection opened with the same name, and deletes propagate back.
    #[test]
    fn blob_round_trips() {
        let connection1 = Connection::open_memory(Some("blob_round_trips"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                )"})
            .unwrap()()
        .unwrap();

        let blob = &[0, 1, 2, 4, 8, 16, 32, 64];

        let mut write =
            Statement::prepare(&connection1, "INSERT INTO blobs (data) VALUES (?)").unwrap();
        write.bind_blob(1, blob).unwrap();
        assert_eq!(write.step().unwrap(), StepResult::Done);

        // Read the blob back through a second connection opened with the same name.
        let connection2 = Connection::open_memory(Some("blob_round_trips"));
        let mut read = Statement::prepare(&connection2, "SELECT * FROM blobs").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Row);
        assert_eq!(read.column_blob(0).unwrap(), blob);
        assert_eq!(read.step().unwrap(), StepResult::Done);

        // Delete the blob via the second connection and verify the first no longer sees it.
        connection2.exec("DELETE FROM blobs").unwrap()().unwrap();
        let mut read = Statement::prepare(&connection1, "SELECT * FROM blobs").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Done);
    }

    /// `select_row` yields `None` for an empty table and `Some` once a row exists.
    #[test]
    pub fn maybe_returns_options() {
        let connection = Connection::open_memory(Some("maybe_returns_options"));
        connection
            .exec(indoc! {"
                CREATE TABLE texts (
                    text TEXT
                )"})
            .unwrap()()
        .unwrap();

        assert!(connection
            .select_row::<String>("SELECT text FROM texts")
            .unwrap()()
        .unwrap()
        .is_none());

        let text_to_insert = "This is a test";

        connection
            .exec_bound("INSERT INTO texts VALUES (?)")
            .unwrap()(text_to_insert)
        .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM texts").unwrap()().unwrap(),
            Some(text_to_insert.to_string())
        );
    }
}