statement.rs

  1use std::ffi::{c_int, CStr, CString};
  2use std::marker::PhantomData;
  3use std::{ptr, slice, str};
  4
  5use anyhow::{anyhow, Context, Result};
  6use libsqlite3_sys::*;
  7
  8use crate::bindable::{Bind, Column};
  9use crate::connection::Connection;
 10
 11pub struct Statement<'a> {
 12    raw_statements: Vec<*mut sqlite3_stmt>,
 13    current_statement: usize,
 14    connection: &'a Connection,
 15    phantom: PhantomData<sqlite3_stmt>,
 16}
 17
 18#[derive(Clone, Copy, PartialEq, Eq, Debug)]
 19pub enum StepResult {
 20    Row,
 21    Done,
 22    Misuse,
 23    Other(i32),
 24}
 25
 26#[derive(Clone, Copy, PartialEq, Eq, Debug)]
 27pub enum SqlType {
 28    Text,
 29    Integer,
 30    Blob,
 31    Float,
 32    Null,
 33}
 34
 35impl<'a> Statement<'a> {
 36    pub fn prepare<T: AsRef<str>>(connection: &'a Connection, query: T) -> Result<Self> {
 37        let mut statement = Self {
 38            raw_statements: Default::default(),
 39            current_statement: 0,
 40            connection,
 41            phantom: PhantomData,
 42        };
 43
 44        unsafe {
 45            let sql = CString::new(query.as_ref())?;
 46            let mut remaining_sql = sql.as_c_str();
 47            while {
 48                let remaining_sql_str = remaining_sql.to_str()?.trim();
 49                remaining_sql_str != ";" && !remaining_sql_str.is_empty()
 50            } {
 51                let mut raw_statement = 0 as *mut sqlite3_stmt;
 52                let mut remaining_sql_ptr = ptr::null();
 53                sqlite3_prepare_v2(
 54                    connection.sqlite3,
 55                    remaining_sql.as_ptr(),
 56                    -1,
 57                    &mut raw_statement,
 58                    &mut remaining_sql_ptr,
 59                );
 60                remaining_sql = CStr::from_ptr(remaining_sql_ptr);
 61                statement.raw_statements.push(raw_statement);
 62            }
 63
 64            connection
 65                .last_error()
 66                .with_context(|| format!("Prepare call failed for query:\n{}", query.as_ref()))?;
 67        }
 68
 69        Ok(statement)
 70    }
 71
 72    fn current_statement(&self) -> *mut sqlite3_stmt {
 73        *self.raw_statements.get(self.current_statement).unwrap()
 74    }
 75
 76    pub fn reset(&mut self) {
 77        unsafe {
 78            for raw_statement in self.raw_statements.iter() {
 79                sqlite3_reset(*raw_statement);
 80            }
 81        }
 82        self.current_statement = 0;
 83    }
 84
 85    pub fn parameter_count(&self) -> i32 {
 86        unsafe {
 87            self.raw_statements
 88                .iter()
 89                .map(|raw_statement| sqlite3_bind_parameter_count(*raw_statement))
 90                .max()
 91                .unwrap_or(0)
 92        }
 93    }
 94
 95    pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
 96        let index = index as c_int;
 97        let blob_pointer = blob.as_ptr() as *const _;
 98        let len = blob.len() as c_int;
 99        unsafe {
100            for raw_statement in self.raw_statements.iter() {
101                sqlite3_bind_blob(*raw_statement, index, blob_pointer, len, SQLITE_TRANSIENT());
102            }
103        }
104        self.connection.last_error()
105    }
106
107    pub fn column_blob<'b>(&'b mut self, index: i32) -> Result<&'b [u8]> {
108        let index = index as c_int;
109        let pointer = unsafe { sqlite3_column_blob(self.current_statement(), index) };
110
111        self.connection.last_error()?;
112        if pointer.is_null() {
113            return Ok(&[]);
114        }
115        let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
116        self.connection.last_error()?;
117        unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
118    }
119
120    pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
121        let index = index as c_int;
122
123        unsafe {
124            for raw_statement in self.raw_statements.iter() {
125                sqlite3_bind_double(*raw_statement, index, double);
126            }
127        }
128        self.connection.last_error()
129    }
130
131    pub fn column_double(&self, index: i32) -> Result<f64> {
132        let index = index as c_int;
133        let result = unsafe { sqlite3_column_double(self.current_statement(), index) };
134        self.connection.last_error()?;
135        Ok(result)
136    }
137
138    pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
139        let index = index as c_int;
140
141        unsafe {
142            for raw_statement in self.raw_statements.iter() {
143                sqlite3_bind_int(*raw_statement, index, int);
144            }
145        };
146        self.connection.last_error()
147    }
148
149    pub fn column_int(&self, index: i32) -> Result<i32> {
150        let index = index as c_int;
151        let result = unsafe { sqlite3_column_int(self.current_statement(), index) };
152        self.connection.last_error()?;
153        Ok(result)
154    }
155
156    pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
157        let index = index as c_int;
158        unsafe {
159            for raw_statement in self.raw_statements.iter() {
160                sqlite3_bind_int64(*raw_statement, index, int);
161            }
162        }
163        self.connection.last_error()
164    }
165
166    pub fn column_int64(&self, index: i32) -> Result<i64> {
167        let index = index as c_int;
168        let result = unsafe { sqlite3_column_int64(self.current_statement(), index) };
169        self.connection.last_error()?;
170        Ok(result)
171    }
172
173    pub fn bind_null(&self, index: i32) -> Result<()> {
174        let index = index as c_int;
175        unsafe {
176            for raw_statement in self.raw_statements.iter() {
177                sqlite3_bind_null(*raw_statement, index);
178            }
179        }
180        self.connection.last_error()
181    }
182
183    pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
184        let index = index as c_int;
185        let text_pointer = text.as_ptr() as *const _;
186        let len = text.len() as c_int;
187        unsafe {
188            for raw_statement in self.raw_statements.iter() {
189                sqlite3_bind_text(*raw_statement, index, text_pointer, len, SQLITE_TRANSIENT());
190            }
191        }
192        self.connection.last_error()
193    }
194
195    pub fn column_text<'b>(&'b mut self, index: i32) -> Result<&'b str> {
196        let index = index as c_int;
197        let pointer = unsafe { sqlite3_column_text(self.current_statement(), index) };
198
199        self.connection.last_error()?;
200        if pointer.is_null() {
201            return Ok("");
202        }
203        let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
204        self.connection.last_error()?;
205
206        let slice = unsafe { slice::from_raw_parts(pointer as *const u8, len) };
207        Ok(str::from_utf8(slice)?)
208    }
209
210    pub fn bind<T: Bind>(&self, value: T, index: i32) -> Result<i32> {
211        debug_assert!(index > 0);
212        value.bind(self, index)
213    }
214
215    pub fn column<T: Column>(&mut self) -> Result<T> {
216        let (result, _) = T::column(self, 0)?;
217        Ok(result)
218    }
219
220    pub fn column_type(&mut self, index: i32) -> Result<SqlType> {
221        let result = unsafe { sqlite3_column_type(self.current_statement(), index) };
222        self.connection.last_error()?;
223        match result {
224            SQLITE_INTEGER => Ok(SqlType::Integer),
225            SQLITE_FLOAT => Ok(SqlType::Float),
226            SQLITE_TEXT => Ok(SqlType::Text),
227            SQLITE_BLOB => Ok(SqlType::Blob),
228            SQLITE_NULL => Ok(SqlType::Null),
229            _ => Err(anyhow!("Column type returned was incorrect ")),
230        }
231    }
232
233    pub fn with_bindings(&mut self, bindings: impl Bind) -> Result<&mut Self> {
234        self.bind(bindings, 1)?;
235        Ok(self)
236    }
237
238    fn step(&mut self) -> Result<StepResult> {
239        unsafe {
240            match sqlite3_step(self.current_statement()) {
241                SQLITE_ROW => Ok(StepResult::Row),
242                SQLITE_DONE => {
243                    if self.current_statement >= self.raw_statements.len() - 1 {
244                        Ok(StepResult::Done)
245                    } else {
246                        self.current_statement += 1;
247                        self.step()
248                    }
249                }
250                SQLITE_MISUSE => Ok(StepResult::Misuse),
251                other => self
252                    .connection
253                    .last_error()
254                    .map(|_| StepResult::Other(other)),
255            }
256        }
257    }
258
259    pub fn insert(&mut self) -> Result<i64> {
260        self.exec()?;
261        Ok(self.connection.last_insert_id())
262    }
263
264    pub fn exec(&mut self) -> Result<()> {
265        fn logic(this: &mut Statement) -> Result<()> {
266            while this.step()? == StepResult::Row {}
267            Ok(())
268        }
269        let result = logic(self);
270        self.reset();
271        result
272    }
273
274    pub fn map<R>(&mut self, callback: impl FnMut(&mut Statement) -> Result<R>) -> Result<Vec<R>> {
275        fn logic<R>(
276            this: &mut Statement,
277            mut callback: impl FnMut(&mut Statement) -> Result<R>,
278        ) -> Result<Vec<R>> {
279            let mut mapped_rows = Vec::new();
280            while this.step()? == StepResult::Row {
281                mapped_rows.push(callback(this)?);
282            }
283            Ok(mapped_rows)
284        }
285
286        let result = logic(self, callback);
287        self.reset();
288        result
289    }
290
291    pub fn rows<R: Column>(&mut self) -> Result<Vec<R>> {
292        self.map(|s| s.column::<R>())
293    }
294
295    pub fn single<R>(&mut self, callback: impl FnOnce(&mut Statement) -> Result<R>) -> Result<R> {
296        fn logic<R>(
297            this: &mut Statement,
298            callback: impl FnOnce(&mut Statement) -> Result<R>,
299        ) -> Result<R> {
300            if this.step()? != StepResult::Row {
301                return Err(anyhow!(
302                    "Single(Map) called with query that returns no rows."
303                ));
304            }
305            callback(this)
306        }
307        let result = logic(self, callback);
308        self.reset();
309        result
310    }
311
312    pub fn row<R: Column>(&mut self) -> Result<R> {
313        self.single(|this| this.column::<R>())
314    }
315
316    pub fn maybe<R>(
317        &mut self,
318        callback: impl FnOnce(&mut Statement) -> Result<R>,
319    ) -> Result<Option<R>> {
320        fn logic<R>(
321            this: &mut Statement,
322            callback: impl FnOnce(&mut Statement) -> Result<R>,
323        ) -> Result<Option<R>> {
324            if this.step()? != StepResult::Row {
325                return Ok(None);
326            }
327            callback(this).map(|r| Some(r))
328        }
329        let result = logic(self, callback);
330        self.reset();
331        result
332    }
333
334    pub fn maybe_row<R: Column>(&mut self) -> Result<Option<R>> {
335        self.maybe(|this| this.column::<R>())
336    }
337}
338
339impl<'a> Drop for Statement<'a> {
340    fn drop(&mut self) {
341        unsafe {
342            for raw_statement in self.raw_statements.iter() {
343                sqlite3_finalize(*raw_statement);
344            }
345        }
346    }
347}
348
#[cfg(test)]
mod test {
    use indoc::indoc;

    use crate::{
        connection::Connection,
        statement::{Statement, StepResult},
    };

    /// A blob written through one connection reads back unchanged through a second
    /// connection opened with the same in-memory database name, and a delete made
    /// through the second connection is visible to the first.
    #[test]
    fn blob_round_trips() {
        let writer_connection = Connection::open_memory("blob_round_trips");
        writer_connection
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                )"})
            .unwrap()()
        .unwrap();

        let expected_blob = &[0, 1, 2, 4, 8, 16, 32, 64];

        let mut insert_statement =
            Statement::prepare(&writer_connection, "INSERT INTO blobs (data) VALUES (?)").unwrap();
        insert_statement.bind_blob(1, expected_blob).unwrap();
        assert_eq!(insert_statement.step().unwrap(), StepResult::Done);

        // Read the blob back through a second connection to the same database.
        let reader_connection = Connection::open_memory("blob_round_trips");
        let mut select_statement =
            Statement::prepare(&reader_connection, "SELECT * FROM blobs").unwrap();
        assert_eq!(select_statement.step().unwrap(), StepResult::Row);
        assert_eq!(select_statement.column_blob(0).unwrap(), expected_blob);
        assert_eq!(select_statement.step().unwrap(), StepResult::Done);

        // Delete the blob through the reader and check the writer no longer sees it.
        reader_connection.exec("DELETE FROM blobs").unwrap()().unwrap();
        let mut select_statement =
            Statement::prepare(&writer_connection, "SELECT * FROM blobs").unwrap();
        assert_eq!(select_statement.step().unwrap(), StepResult::Done);
    }

    /// `select_row` yields `None` for an empty table and `Some` once a row exists.
    #[test]
    pub fn maybe_returns_options() {
        let connection = Connection::open_memory("maybe_returns_options");
        connection
            .exec(indoc! {"
                CREATE TABLE texts (
                    text TEXT 
                )"})
            .unwrap()()
        .unwrap();

        let empty_result = connection
            .select_row::<String>("SELECT text FROM texts")
            .unwrap()()
        .unwrap();
        assert!(empty_result.is_none());

        let inserted_text = "This is a test";

        connection
            .exec_bound("INSERT INTO texts VALUES (?)")
            .unwrap()(inserted_text)
        .unwrap();

        let found = connection.select_row("SELECT text FROM texts").unwrap()().unwrap();
        assert_eq!(found, Some(inserted_text.to_string()));
    }
}