statement.rs

  1use std::ffi::{c_int, CString};
  2use std::marker::PhantomData;
  3use std::{slice, str};
  4
  5use anyhow::{anyhow, Context, Result};
  6use libsqlite3_sys::*;
  7
  8use crate::bindable::{Bind, Column};
  9use crate::connection::Connection;
 10
 11pub struct Statement<'a> {
 12    raw_statement: *mut sqlite3_stmt,
 13    connection: &'a Connection,
 14    phantom: PhantomData<sqlite3_stmt>,
 15}
 16
 17#[derive(Clone, Copy, PartialEq, Eq, Debug)]
 18pub enum StepResult {
 19    Row,
 20    Done,
 21    Misuse,
 22    Other(i32),
 23}
 24
 25#[derive(Clone, Copy, PartialEq, Eq, Debug)]
 26pub enum SqlType {
 27    Text,
 28    Integer,
 29    Blob,
 30    Float,
 31    Null,
 32}
 33
 34impl<'a> Statement<'a> {
 35    pub fn prepare<T: AsRef<str>>(connection: &'a Connection, query: T) -> Result<Self> {
 36        let mut statement = Self {
 37            raw_statement: 0 as *mut _,
 38            connection,
 39            phantom: PhantomData,
 40        };
 41
 42        unsafe {
 43            sqlite3_prepare_v2(
 44                connection.sqlite3,
 45                CString::new(query.as_ref())?.as_ptr(),
 46                -1,
 47                &mut statement.raw_statement,
 48                0 as *mut _,
 49            );
 50
 51            connection
 52                .last_error()
 53                .with_context(|| format!("Prepare call failed for query:\n{}", query.as_ref()))?;
 54        }
 55
 56        Ok(statement)
 57    }
 58
 59    pub fn reset(&mut self) {
 60        unsafe {
 61            sqlite3_reset(self.raw_statement);
 62        }
 63    }
 64
 65    pub fn parameter_count(&self) -> i32 {
 66        unsafe { sqlite3_bind_parameter_count(self.raw_statement) }
 67    }
 68
 69    pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
 70        // dbg!("bind blob", index);
 71        let index = index as c_int;
 72        let blob_pointer = blob.as_ptr() as *const _;
 73        let len = blob.len() as c_int;
 74        unsafe {
 75            sqlite3_bind_blob(
 76                self.raw_statement,
 77                index,
 78                blob_pointer,
 79                len,
 80                SQLITE_TRANSIENT(),
 81            );
 82        }
 83        self.connection.last_error()
 84    }
 85
 86    pub fn column_blob<'b>(&'b mut self, index: i32) -> Result<&'b [u8]> {
 87        let index = index as c_int;
 88        let pointer = unsafe { sqlite3_column_blob(self.raw_statement, index) };
 89
 90        self.connection.last_error()?;
 91        if pointer.is_null() {
 92            return Ok(&[]);
 93        }
 94        let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize };
 95        self.connection.last_error()?;
 96        unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
 97    }
 98
 99    pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
100        // dbg!("bind double", index);
101        let index = index as c_int;
102
103        unsafe {
104            sqlite3_bind_double(self.raw_statement, index, double);
105        }
106        self.connection.last_error()
107    }
108
109    pub fn column_double(&self, index: i32) -> Result<f64> {
110        let index = index as c_int;
111        let result = unsafe { sqlite3_column_double(self.raw_statement, index) };
112        self.connection.last_error()?;
113        Ok(result)
114    }
115
116    pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
117        // dbg!("bind int", index);
118        let index = index as c_int;
119
120        unsafe {
121            sqlite3_bind_int(self.raw_statement, index, int);
122        };
123        self.connection.last_error()
124    }
125
126    pub fn column_int(&self, index: i32) -> Result<i32> {
127        let index = index as c_int;
128        let result = unsafe { sqlite3_column_int(self.raw_statement, index) };
129        self.connection.last_error()?;
130        Ok(result)
131    }
132
133    pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
134        // dbg!("bind int64", index);
135        let index = index as c_int;
136        unsafe {
137            sqlite3_bind_int64(self.raw_statement, index, int);
138        }
139        self.connection.last_error()
140    }
141
142    pub fn column_int64(&self, index: i32) -> Result<i64> {
143        let index = index as c_int;
144        let result = unsafe { sqlite3_column_int64(self.raw_statement, index) };
145        self.connection.last_error()?;
146        Ok(result)
147    }
148
149    pub fn bind_null(&self, index: i32) -> Result<()> {
150        // dbg!("bind null", index);
151        let index = index as c_int;
152        unsafe {
153            sqlite3_bind_null(self.raw_statement, index);
154        }
155        self.connection.last_error()
156    }
157
158    pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
159        // dbg!("bind text", index, text);
160        let index = index as c_int;
161        let text_pointer = text.as_ptr() as *const _;
162        let len = text.len() as c_int;
163        unsafe {
164            sqlite3_bind_text(
165                self.raw_statement,
166                index,
167                text_pointer,
168                len,
169                SQLITE_TRANSIENT(),
170            );
171        }
172        self.connection.last_error()
173    }
174
175    pub fn column_text<'b>(&'b mut self, index: i32) -> Result<&'b str> {
176        let index = index as c_int;
177        let pointer = unsafe { sqlite3_column_text(self.raw_statement, index) };
178
179        self.connection.last_error()?;
180        if pointer.is_null() {
181            return Ok("");
182        }
183        let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize };
184        self.connection.last_error()?;
185
186        let slice = unsafe { slice::from_raw_parts(pointer as *const u8, len) };
187        Ok(str::from_utf8(slice)?)
188    }
189
190    pub fn bind<T: Bind>(&self, value: T, index: i32) -> Result<i32> {
191        debug_assert!(index > 0);
192        value.bind(self, index)
193    }
194
195    pub fn column<T: Column>(&mut self) -> Result<T> {
196        let (result, _) = T::column(self, 0)?;
197        Ok(result)
198    }
199
200    pub fn column_type(&mut self, index: i32) -> Result<SqlType> {
201        let result = unsafe { sqlite3_column_type(self.raw_statement, index) }; // SELECT <FRIEND> FROM TABLE
202        self.connection.last_error()?;
203        match result {
204            SQLITE_INTEGER => Ok(SqlType::Integer),
205            SQLITE_FLOAT => Ok(SqlType::Float),
206            SQLITE_TEXT => Ok(SqlType::Text),
207            SQLITE_BLOB => Ok(SqlType::Blob),
208            SQLITE_NULL => Ok(SqlType::Null),
209            _ => Err(anyhow!("Column type returned was incorrect ")),
210        }
211    }
212
213    pub fn with_bindings(&mut self, bindings: impl Bind) -> Result<&mut Self> {
214        self.bind(bindings, 1)?;
215        Ok(self)
216    }
217
218    fn step(&mut self) -> Result<StepResult> {
219        unsafe {
220            match sqlite3_step(self.raw_statement) {
221                SQLITE_ROW => Ok(StepResult::Row),
222                SQLITE_DONE => Ok(StepResult::Done),
223                SQLITE_MISUSE => Ok(StepResult::Misuse),
224                other => self
225                    .connection
226                    .last_error()
227                    .map(|_| StepResult::Other(other)),
228            }
229        }
230    }
231
232    pub fn insert(&mut self) -> Result<i64> {
233        self.exec()?;
234        Ok(self.connection.last_insert_id())
235    }
236
237    pub fn exec(&mut self) -> Result<()> {
238        fn logic(this: &mut Statement) -> Result<()> {
239            while this.step()? == StepResult::Row {}
240            Ok(())
241        }
242        let result = logic(self);
243        self.reset();
244        result
245    }
246
247    pub fn map<R>(&mut self, callback: impl FnMut(&mut Statement) -> Result<R>) -> Result<Vec<R>> {
248        fn logic<R>(
249            this: &mut Statement,
250            mut callback: impl FnMut(&mut Statement) -> Result<R>,
251        ) -> Result<Vec<R>> {
252            let mut mapped_rows = Vec::new();
253            while this.step()? == StepResult::Row {
254                mapped_rows.push(callback(this)?);
255            }
256            Ok(mapped_rows)
257        }
258
259        let result = logic(self, callback);
260        self.reset();
261        result
262    }
263
264    pub fn rows<R: Column>(&mut self) -> Result<Vec<R>> {
265        self.map(|s| s.column::<R>())
266    }
267
268    pub fn single<R>(&mut self, callback: impl FnOnce(&mut Statement) -> Result<R>) -> Result<R> {
269        fn logic<R>(
270            this: &mut Statement,
271            callback: impl FnOnce(&mut Statement) -> Result<R>,
272        ) -> Result<R> {
273            if this.step()? != StepResult::Row {
274                return Err(anyhow!(
275                    "Single(Map) called with query that returns no rows."
276                ));
277            }
278            callback(this)
279        }
280        let result = logic(self, callback);
281        self.reset();
282        result
283    }
284
285    pub fn row<R: Column>(&mut self) -> Result<R> {
286        self.single(|this| this.column::<R>())
287    }
288
289    pub fn maybe<R>(
290        &mut self,
291        callback: impl FnOnce(&mut Statement) -> Result<R>,
292    ) -> Result<Option<R>> {
293        fn logic<R>(
294            this: &mut Statement,
295            callback: impl FnOnce(&mut Statement) -> Result<R>,
296        ) -> Result<Option<R>> {
297            if this.step()? != StepResult::Row {
298                return Ok(None);
299            }
300            callback(this).map(|r| Some(r))
301        }
302        let result = logic(self, callback);
303        self.reset();
304        result
305    }
306
307    pub fn maybe_row<R: Column>(&mut self) -> Result<Option<R>> {
308        self.maybe(|this| this.column::<R>())
309    }
310}
311
312impl<'a> Drop for Statement<'a> {
313    fn drop(&mut self) {
314        unsafe { sqlite3_finalize(self.raw_statement) };
315    }
316}
317
#[cfg(test)]
mod test {
    use indoc::indoc;

    use crate::{connection::Connection, statement::StepResult};

    #[test]
    fn blob_round_trips() {
        let connection1 = Connection::open_memory("blob_round_trips");
        connection1
            .exec(indoc! {"
                CREATE TABLE blobs (
                data BLOB
                );"})
            .unwrap();

        let blob = &[0, 1, 2, 4, 8, 16, 32, 64];

        let mut write = connection1
            .prepare("INSERT INTO blobs (data) VALUES (?);")
            .unwrap();
        write.bind_blob(1, blob).unwrap();
        assert_eq!(write.step().unwrap(), StepResult::Done);

        // Read the blob back through a second connection opened with the same
        // in-memory database name.
        let connection2 = Connection::open_memory("blob_round_trips");
        let mut read = connection2.prepare("SELECT * FROM blobs;").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Row);
        assert_eq!(read.column_blob(0).unwrap(), blob);
        assert_eq!(read.step().unwrap(), StepResult::Done);

        // Delete the added blob and verify it's deleted on the other side
        connection2.exec("DELETE FROM blobs;").unwrap();
        let mut read = connection1.prepare("SELECT * FROM blobs;").unwrap();
        assert_eq!(read.step().unwrap(), StepResult::Done);
    }

    #[test]
    fn maybe_returns_options() {
        let connection = Connection::open_memory("maybe_returns_options");
        connection
            .exec(indoc! {"
                CREATE TABLE texts (
                    text TEXT 
                );"})
            .unwrap();

        assert!(connection
            .prepare("SELECT text FROM texts")
            .unwrap()
            .maybe_row::<String>()
            .unwrap()
            .is_none());

        let text_to_insert = "This is a test";

        connection
            .prepare("INSERT INTO texts VALUES (?)")
            .unwrap()
            .with_bindings(text_to_insert)
            .unwrap()
            .exec()
            .unwrap();

        assert_eq!(
            connection
                .prepare("SELECT text FROM texts")
                .unwrap()
                .maybe_row::<String>()
                .unwrap(),
            Some(text_to_insert.to_string())
        );
    }

    #[test]
    fn row_errors_when_query_returns_no_rows() {
        let connection = Connection::open_memory("row_errors_when_query_returns_no_rows");
        connection
            .exec(indoc! {"
                CREATE TABLE texts (
                    text TEXT
                );"})
            .unwrap();

        assert!(connection
            .prepare("SELECT text FROM texts")
            .unwrap()
            .row::<String>()
            .is_err());
    }

    #[test]
    fn rows_collects_every_row() {
        let connection = Connection::open_memory("rows_collects_every_row");
        connection
            .exec(indoc! {"
                CREATE TABLE texts (
                    text TEXT
                );"})
            .unwrap();

        for text in &["first", "second"] {
            connection
                .prepare("INSERT INTO texts VALUES (?)")
                .unwrap()
                .with_bindings(*text)
                .unwrap()
                .exec()
                .unwrap();
        }

        assert_eq!(
            connection
                .prepare("SELECT text FROM texts ORDER BY rowid")
                .unwrap()
                .rows::<String>()
                .unwrap(),
            vec!["first".to_string(), "second".to_string()]
        );
    }
}