statement.rs

  1use std::ffi::{c_int, CStr, CString};
  2use std::marker::PhantomData;
  3use std::{ptr, slice, str};
  4
  5use anyhow::{anyhow, Context, Result};
  6use libsqlite3_sys::*;
  7
  8use crate::bindable::{Bind, Column};
  9use crate::connection::Connection;
 10
 11pub struct Statement<'a> {
 12    raw_statements: Vec<*mut sqlite3_stmt>,
 13    current_statement: usize,
 14    connection: &'a Connection,
 15    phantom: PhantomData<sqlite3_stmt>,
 16}
 17
 /// Outcome of advancing a statement by one `sqlite3_step` call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StepResult {
    /// A result row is available; its columns can be read with the `column_*` methods.
    Row,
    /// Every statement in the query has run to completion.
    Done,
    /// `sqlite3_step` reported `SQLITE_MISUSE`.
    Misuse,
    /// Any other raw result code, returned when the connection reported no error for it.
    Other(i32),
}
 25
 /// Storage class of a column value in the current row, as reported by `sqlite3_column_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlType {
    /// `SQLITE_TEXT`
    Text,
    /// `SQLITE_INTEGER`
    Integer,
    /// `SQLITE_BLOB`
    Blob,
    /// `SQLITE_FLOAT`
    Float,
    /// `SQLITE_NULL`
    Null,
}
 34
 35impl<'a> Statement<'a> {
 36    pub fn prepare<T: AsRef<str>>(connection: &'a Connection, query: T) -> Result<Self> {
 37        let mut statement = Self {
 38            raw_statements: Default::default(),
 39            current_statement: 0,
 40            connection,
 41            phantom: PhantomData,
 42        };
 43
 44        unsafe {
 45            let sql = CString::new(query.as_ref())?;
 46            let mut remaining_sql = sql.as_c_str();
 47            while {
 48                let remaining_sql_str = remaining_sql.to_str()?.trim();
 49                remaining_sql_str != ";" && !remaining_sql_str.is_empty()
 50            } {
 51                let mut raw_statement = 0 as *mut sqlite3_stmt;
 52                let mut remaining_sql_ptr = ptr::null();
 53                sqlite3_prepare_v2(
 54                    connection.sqlite3,
 55                    remaining_sql.as_ptr(),
 56                    -1,
 57                    &mut raw_statement,
 58                    &mut remaining_sql_ptr,
 59                );
 60                remaining_sql = CStr::from_ptr(remaining_sql_ptr);
 61                statement.raw_statements.push(raw_statement);
 62
 63                connection.last_error().with_context(|| {
 64                    format!("Prepare call failed for query:\n{}", query.as_ref())
 65                })?;
 66            }
 67        }
 68
 69        Ok(statement)
 70    }
 71
 72    fn current_statement(&self) -> *mut sqlite3_stmt {
 73        *self.raw_statements.get(self.current_statement).unwrap()
 74    }
 75
 76    pub fn reset(&mut self) {
 77        unsafe {
 78            for raw_statement in self.raw_statements.iter() {
 79                sqlite3_reset(*raw_statement);
 80            }
 81        }
 82        self.current_statement = 0;
 83    }
 84
 85    pub fn parameter_count(&self) -> i32 {
 86        unsafe {
 87            self.raw_statements
 88                .iter()
 89                .map(|raw_statement| sqlite3_bind_parameter_count(*raw_statement))
 90                .max()
 91                .unwrap_or(0)
 92        }
 93    }
 94
 95    pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
 96        let index = index as c_int;
 97        let blob_pointer = blob.as_ptr() as *const _;
 98        let len = blob.len() as c_int;
 99        unsafe {
100            for raw_statement in self.raw_statements.iter() {
101                sqlite3_bind_blob(*raw_statement, index, blob_pointer, len, SQLITE_TRANSIENT());
102            }
103        }
104        self.connection.last_error()
105    }
106
107    pub fn column_blob<'b>(&'b mut self, index: i32) -> Result<&'b [u8]> {
108        let index = index as c_int;
109        let pointer = unsafe { sqlite3_column_blob(self.current_statement(), index) };
110
111        self.connection.last_error()?;
112        if pointer.is_null() {
113            return Ok(&[]);
114        }
115        let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
116        self.connection.last_error()?;
117        unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
118    }
119
120    pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
121        let index = index as c_int;
122
123        unsafe {
124            for raw_statement in self.raw_statements.iter() {
125                sqlite3_bind_double(*raw_statement, index, double);
126            }
127        }
128        self.connection.last_error()
129    }
130
131    pub fn column_double(&self, index: i32) -> Result<f64> {
132        let index = index as c_int;
133        let result = unsafe { sqlite3_column_double(self.current_statement(), index) };
134        self.connection.last_error()?;
135        Ok(result)
136    }
137
138    pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
139        let index = index as c_int;
140
141        unsafe {
142            for raw_statement in self.raw_statements.iter() {
143                sqlite3_bind_int(*raw_statement, index, int);
144            }
145        };
146        self.connection.last_error()
147    }
148
149    pub fn column_int(&self, index: i32) -> Result<i32> {
150        let index = index as c_int;
151        let result = unsafe { sqlite3_column_int(self.current_statement(), index) };
152        self.connection.last_error()?;
153        Ok(result)
154    }
155
156    pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
157        let index = index as c_int;
158        unsafe {
159            for raw_statement in self.raw_statements.iter() {
160                sqlite3_bind_int64(*raw_statement, index, int);
161            }
162        }
163        self.connection.last_error()
164    }
165
166    pub fn column_int64(&self, index: i32) -> Result<i64> {
167        let index = index as c_int;
168        let result = unsafe { sqlite3_column_int64(self.current_statement(), index) };
169        self.connection.last_error()?;
170        Ok(result)
171    }
172
173    pub fn bind_null(&self, index: i32) -> Result<()> {
174        let index = index as c_int;
175        unsafe {
176            for raw_statement in self.raw_statements.iter() {
177                sqlite3_bind_null(*raw_statement, index);
178            }
179        }
180        self.connection.last_error()
181    }
182
183    pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
184        let index = index as c_int;
185        let text_pointer = text.as_ptr() as *const _;
186        let len = text.len() as c_int;
187        unsafe {
188            for raw_statement in self.raw_statements.iter() {
189                sqlite3_bind_text(*raw_statement, index, text_pointer, len, SQLITE_TRANSIENT());
190            }
191        }
192        self.connection.last_error()
193    }
194
195    pub fn column_text<'b>(&'b mut self, index: i32) -> Result<&'b str> {
196        let index = index as c_int;
197        let pointer = unsafe { sqlite3_column_text(self.current_statement(), index) };
198
199        self.connection.last_error()?;
200        if pointer.is_null() {
201            return Ok("");
202        }
203        let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
204        self.connection.last_error()?;
205
206        let slice = unsafe { slice::from_raw_parts(pointer as *const u8, len) };
207        Ok(str::from_utf8(slice)?)
208    }
209
210    pub fn bind<T: Bind>(&self, value: T, index: i32) -> Result<i32> {
211        debug_assert!(index > 0);
212        value.bind(self, index)
213    }
214
215    pub fn column<T: Column>(&mut self) -> Result<T> {
216        let (result, _) = T::column(self, 0)?;
217        Ok(result)
218    }
219
220    pub fn column_type(&mut self, index: i32) -> Result<SqlType> {
221        let result = unsafe { sqlite3_column_type(self.current_statement(), index) };
222        self.connection.last_error()?;
223        match result {
224            SQLITE_INTEGER => Ok(SqlType::Integer),
225            SQLITE_FLOAT => Ok(SqlType::Float),
226            SQLITE_TEXT => Ok(SqlType::Text),
227            SQLITE_BLOB => Ok(SqlType::Blob),
228            SQLITE_NULL => Ok(SqlType::Null),
229            _ => Err(anyhow!("Column type returned was incorrect ")),
230        }
231    }
232
233    pub fn with_bindings(&mut self, bindings: impl Bind) -> Result<&mut Self> {
234        self.bind(bindings, 1)?;
235        Ok(self)
236    }
237
238    fn step(&mut self) -> Result<StepResult> {
239        unsafe {
240            match sqlite3_step(self.current_statement()) {
241                SQLITE_ROW => Ok(StepResult::Row),
242                SQLITE_DONE => {
243                    if self.current_statement >= self.raw_statements.len() - 1 {
244                        Ok(StepResult::Done)
245                    } else {
246                        self.current_statement += 1;
247                        self.step()
248                    }
249                }
250                SQLITE_MISUSE => Ok(StepResult::Misuse),
251                other => self
252                    .connection
253                    .last_error()
254                    .map(|_| StepResult::Other(other)),
255            }
256        }
257    }
258
259    pub fn exec(&mut self) -> Result<()> {
260        fn logic(this: &mut Statement) -> Result<()> {
261            while this.step()? == StepResult::Row {}
262            Ok(())
263        }
264        let result = logic(self);
265        self.reset();
266        result
267    }
268
269    pub fn map<R>(&mut self, callback: impl FnMut(&mut Statement) -> Result<R>) -> Result<Vec<R>> {
270        fn logic<R>(
271            this: &mut Statement,
272            mut callback: impl FnMut(&mut Statement) -> Result<R>,
273        ) -> Result<Vec<R>> {
274            let mut mapped_rows = Vec::new();
275            while this.step()? == StepResult::Row {
276                mapped_rows.push(callback(this)?);
277            }
278            Ok(mapped_rows)
279        }
280
281        let result = logic(self, callback);
282        self.reset();
283        result
284    }
285
286    pub fn rows<R: Column>(&mut self) -> Result<Vec<R>> {
287        self.map(|s| s.column::<R>())
288    }
289
290    pub fn single<R>(&mut self, callback: impl FnOnce(&mut Statement) -> Result<R>) -> Result<R> {
291        fn logic<R>(
292            this: &mut Statement,
293            callback: impl FnOnce(&mut Statement) -> Result<R>,
294        ) -> Result<R> {
295            if this.step()? != StepResult::Row {
296                return Err(anyhow!(
297                    "Single(Map) called with query that returns no rows."
298                ));
299            }
300            callback(this)
301        }
302        let result = logic(self, callback);
303        self.reset();
304        result
305    }
306
307    pub fn row<R: Column>(&mut self) -> Result<R> {
308        self.single(|this| this.column::<R>())
309    }
310
311    pub fn maybe<R>(
312        &mut self,
313        callback: impl FnOnce(&mut Statement) -> Result<R>,
314    ) -> Result<Option<R>> {
315        fn logic<R>(
316            this: &mut Statement,
317            callback: impl FnOnce(&mut Statement) -> Result<R>,
318        ) -> Result<Option<R>> {
319            if this.step()? != StepResult::Row {
320                return Ok(None);
321            }
322            callback(this).map(|r| Some(r))
323        }
324        let result = logic(self, callback);
325        self.reset();
326        result
327    }
328
329    pub fn maybe_row<R: Column>(&mut self) -> Result<Option<R>> {
330        self.maybe(|this| this.column::<R>())
331    }
332}
333
334impl<'a> Drop for Statement<'a> {
335    fn drop(&mut self) {
336        unsafe {
337            for raw_statement in self.raw_statements.iter() {
338                sqlite3_finalize(*raw_statement);
339            }
340        }
341    }
342}
343
#[cfg(test)]
mod test {
    use indoc::indoc;

    use crate::{
        connection::Connection,
        statement::{Statement, StepResult},
    };

    /// Writes a blob through one connection and reads it back through a second connection
    /// opened with the same in-memory database name. Then deletes it through the second
    /// connection and checks that the first no longer sees it.
    #[test]
    fn blob_round_trips() {
        let connection1 = Connection::open_memory(Some("blob_round_trips"));
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                )"})
            .unwrap()()
        .unwrap();

        let blob = &[0, 1, 2, 4, 8, 16, 32, 64];

        let mut write =
            Statement::prepare(&connection1, "INSERT INTO blobs (data) VALUES (?)").unwrap();
        write.bind_blob(1, blob).unwrap();
        assert_eq!(write.step().unwrap(), StepResult::Done);

        // Read the blob back through a second connection opened with the same name.
        let connection2 = Connection::open_memory(Some("blob_round_trips"));
        let mut read = Statement::prepare(&connection2, "SELECT * FROM blobs").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Row);
        assert_eq!(read.column_blob(0).unwrap(), blob);
        assert_eq!(read.step().unwrap(), StepResult::Done);

        // Delete the blob through the second connection and verify it is gone from the first.
        connection2.exec("DELETE FROM blobs").unwrap()().unwrap();
        let mut read = Statement::prepare(&connection1, "SELECT * FROM blobs").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Done);
    }

    /// `select_row` yields `None` for an empty table and `Some` once a row has been inserted.
    #[test]
    pub fn maybe_returns_options() {
        let connection = Connection::open_memory(Some("maybe_returns_options"));
        connection
            .exec(indoc! {"
                CREATE TABLE texts (
                    text TEXT 
                )"})
            .unwrap()()
        .unwrap();

        assert!(connection
            .select_row::<String>("SELECT text FROM texts")
            .unwrap()()
        .unwrap()
        .is_none());

        let text_to_insert = "This is a test";

        connection
            .exec_bound("INSERT INTO texts VALUES (?)")
            .unwrap()(text_to_insert)
        .unwrap();

        assert_eq!(
            connection.select_row("SELECT text FROM texts").unwrap()().unwrap(),
            Some(text_to_insert.to_string())
        );
    }
}