statement.rs

  1use std::ffi::{c_int, CString};
  2use std::marker::PhantomData;
  3use std::{slice, str};
  4
  5use anyhow::{anyhow, Context, Result};
  6use libsqlite3_sys::*;
  7
  8use crate::bindable::{Bind, Column};
  9use crate::connection::{error_to_result, Connection};
 10
 11pub struct Statement<'a> {
 12    raw_statement: *mut sqlite3_stmt,
 13    connection: &'a Connection,
 14    phantom: PhantomData<sqlite3_stmt>,
 15}
 16
 17#[derive(Clone, Copy, PartialEq, Eq, Debug)]
 18pub enum StepResult {
 19    Row,
 20    Done,
 21    Misuse,
 22    Other(i32),
 23}
 24
 25#[derive(Clone, Copy, PartialEq, Eq, Debug)]
 26pub enum SqlType {
 27    Text,
 28    Integer,
 29    Blob,
 30    Float,
 31    Null,
 32}
 33
 34impl<'a> Statement<'a> {
 35    pub fn prepare<T: AsRef<str>>(connection: &'a Connection, query: T) -> Result<Self> {
 36        let mut statement = Self {
 37            raw_statement: 0 as *mut _,
 38            connection,
 39            phantom: PhantomData,
 40        };
 41
 42        unsafe {
 43            sqlite3_prepare_v2(
 44                connection.sqlite3,
 45                CString::new(query.as_ref())?.as_ptr(),
 46                -1,
 47                &mut statement.raw_statement,
 48                0 as *mut _,
 49            );
 50
 51            connection.last_error().context("Prepare call failed.")?;
 52        }
 53
 54        Ok(statement)
 55    }
 56
 57    pub fn reset(&mut self) {
 58        unsafe {
 59            sqlite3_reset(self.raw_statement);
 60        }
 61    }
 62
 63    pub fn parameter_count(&self) -> i32 {
 64        unsafe { sqlite3_bind_parameter_count(self.raw_statement) }
 65    }
 66
 67    pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
 68        // dbg!("bind blob", index);
 69        let index = index as c_int;
 70        let blob_pointer = blob.as_ptr() as *const _;
 71        let len = blob.len() as c_int;
 72        unsafe {
 73            sqlite3_bind_blob(
 74                self.raw_statement,
 75                index,
 76                blob_pointer,
 77                len,
 78                SQLITE_TRANSIENT(),
 79            );
 80        }
 81        self.connection.last_error()
 82    }
 83
 84    pub fn column_blob<'b>(&'b mut self, index: i32) -> Result<&'b [u8]> {
 85        let index = index as c_int;
 86        let pointer = unsafe { sqlite3_column_blob(self.raw_statement, index) };
 87
 88        self.connection.last_error()?;
 89        if pointer.is_null() {
 90            return Ok(&[]);
 91        }
 92        let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize };
 93        self.connection.last_error()?;
 94        unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
 95    }
 96
 97    pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
 98        // dbg!("bind double", index);
 99        let index = index as c_int;
100
101        unsafe {
102            sqlite3_bind_double(self.raw_statement, index, double);
103        }
104        self.connection.last_error()
105    }
106
107    pub fn column_double(&self, index: i32) -> Result<f64> {
108        let index = index as c_int;
109        let result = unsafe { sqlite3_column_double(self.raw_statement, index) };
110        self.connection.last_error()?;
111        Ok(result)
112    }
113
114    pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
115        // dbg!("bind int", index);
116        let index = index as c_int;
117
118        unsafe {
119            sqlite3_bind_int(self.raw_statement, index, int);
120        };
121        self.connection.last_error()
122    }
123
124    pub fn column_int(&self, index: i32) -> Result<i32> {
125        let index = index as c_int;
126        let result = unsafe { sqlite3_column_int(self.raw_statement, index) };
127        self.connection.last_error()?;
128        Ok(result)
129    }
130
131    pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
132        // dbg!("bind int64", index);
133        let index = index as c_int;
134        unsafe {
135            sqlite3_bind_int64(self.raw_statement, index, int);
136        }
137        self.connection.last_error()
138    }
139
140    pub fn column_int64(&self, index: i32) -> Result<i64> {
141        let index = index as c_int;
142        let result = unsafe { sqlite3_column_int64(self.raw_statement, index) };
143        self.connection.last_error()?;
144        Ok(result)
145    }
146
147    pub fn bind_null(&self, index: i32) -> Result<()> {
148        // dbg!("bind null", index);
149        let index = index as c_int;
150        unsafe {
151            sqlite3_bind_null(self.raw_statement, index);
152        }
153        self.connection.last_error()
154    }
155
156    pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
157        // dbg!("bind text", index, text);
158        let index = index as c_int;
159        let text_pointer = text.as_ptr() as *const _;
160        let len = text.len() as c_int;
161        unsafe {
162            sqlite3_bind_text(
163                self.raw_statement,
164                index,
165                text_pointer,
166                len,
167                SQLITE_TRANSIENT(),
168            );
169        }
170        self.connection.last_error()
171    }
172
173    pub fn column_text<'b>(&'b mut self, index: i32) -> Result<&'b str> {
174        let index = index as c_int;
175        let pointer = unsafe { sqlite3_column_text(self.raw_statement, index) };
176
177        self.connection.last_error()?;
178        if pointer.is_null() {
179            return Ok("");
180        }
181        let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize };
182        self.connection.last_error()?;
183
184        let slice = unsafe { slice::from_raw_parts(pointer as *const u8, len) };
185        Ok(str::from_utf8(slice)?)
186    }
187
188    pub fn bind<T: Bind>(&self, value: T, index: i32) -> Result<i32> {
189        debug_assert!(index > 0);
190        value.bind(self, index)
191    }
192
193    pub fn column<T: Column>(&mut self) -> Result<T> {
194        let (result, _) = T::column(self, 0)?;
195        Ok(result)
196    }
197
198    pub fn column_type(&mut self, index: i32) -> Result<SqlType> {
199        let result = unsafe { sqlite3_column_type(self.raw_statement, index) }; // SELECT <FRIEND> FROM TABLE
200        self.connection.last_error()?;
201        match result {
202            SQLITE_INTEGER => Ok(SqlType::Integer),
203            SQLITE_FLOAT => Ok(SqlType::Float),
204            SQLITE_TEXT => Ok(SqlType::Text),
205            SQLITE_BLOB => Ok(SqlType::Blob),
206            SQLITE_NULL => Ok(SqlType::Null),
207            _ => Err(anyhow!("Column type returned was incorrect ")),
208        }
209    }
210
211    pub fn with_bindings(&mut self, bindings: impl Bind) -> Result<&mut Self> {
212        self.bind(bindings, 1)?;
213        Ok(self)
214    }
215
216    fn step(&mut self) -> Result<StepResult> {
217        unsafe {
218            match sqlite3_step(self.raw_statement) {
219                SQLITE_ROW => Ok(StepResult::Row),
220                SQLITE_DONE => Ok(StepResult::Done),
221                SQLITE_MISUSE => Ok(StepResult::Misuse),
222                other => self
223                    .connection
224                    .last_error()
225                    .map(|_| StepResult::Other(other)),
226            }
227        }
228    }
229
230    pub fn insert(&mut self) -> Result<i64> {
231        self.exec()?;
232        Ok(self.connection.last_insert_id())
233    }
234
235    pub fn exec(&mut self) -> Result<()> {
236        fn logic(this: &mut Statement) -> Result<()> {
237            while this.step()? == StepResult::Row {}
238            Ok(())
239        }
240        let result = logic(self);
241        self.reset();
242        result
243    }
244
245    pub fn map<R>(&mut self, callback: impl FnMut(&mut Statement) -> Result<R>) -> Result<Vec<R>> {
246        fn logic<R>(
247            this: &mut Statement,
248            mut callback: impl FnMut(&mut Statement) -> Result<R>,
249        ) -> Result<Vec<R>> {
250            let mut mapped_rows = Vec::new();
251            while this.step()? == StepResult::Row {
252                mapped_rows.push(callback(this)?);
253            }
254            Ok(mapped_rows)
255        }
256
257        let result = logic(self, callback);
258        self.reset();
259        result
260    }
261
262    pub fn rows<R: Column>(&mut self) -> Result<Vec<R>> {
263        self.map(|s| s.column::<R>())
264    }
265
266    pub fn single<R>(&mut self, callback: impl FnOnce(&mut Statement) -> Result<R>) -> Result<R> {
267        fn logic<R>(
268            this: &mut Statement,
269            callback: impl FnOnce(&mut Statement) -> Result<R>,
270        ) -> Result<R> {
271            if this.step()? != StepResult::Row {
272                return Err(anyhow!(
273                    "Single(Map) called with query that returns no rows."
274                ));
275            }
276            callback(this)
277        }
278        let result = logic(self, callback);
279        self.reset();
280        result
281    }
282
283    pub fn row<R: Column>(&mut self) -> Result<R> {
284        self.single(|this| this.column::<R>())
285    }
286
287    pub fn maybe<R>(
288        &mut self,
289        callback: impl FnOnce(&mut Statement) -> Result<R>,
290    ) -> Result<Option<R>> {
291        fn logic<R>(
292            this: &mut Statement,
293            callback: impl FnOnce(&mut Statement) -> Result<R>,
294        ) -> Result<Option<R>> {
295            if this.step()? != StepResult::Row {
296                return Ok(None);
297            }
298            callback(this).map(|r| Some(r))
299        }
300        let result = logic(self, callback);
301        self.reset();
302        result
303    }
304
305    pub fn maybe_row<R: Column>(&mut self) -> Result<Option<R>> {
306        self.maybe(|this| this.column::<R>())
307    }
308}
309
310impl<'a> Drop for Statement<'a> {
311    fn drop(&mut self) {
312        unsafe {
313            let error = sqlite3_finalize(self.raw_statement);
314            error_to_result(error).expect("failed error");
315        };
316    }
317}
318
#[cfg(test)]
mod test {
    use indoc::indoc;

    use crate::{connection::Connection, statement::StepResult};

    /// Writes a blob through one connection to a named in-memory database,
    /// reads it back through a second connection opened with the same name,
    /// then deletes it from the second side and checks the first side sees
    /// the table empty.
    #[test]
    fn blob_round_trips() {
        let writer = Connection::open_memory("blob_round_trips");
        let schema = indoc! {"
            CREATE TABLE blobs (
            data BLOB
            );"};
        writer.exec(schema).unwrap();

        let blob = &[0, 1, 2, 4, 8, 16, 32, 64];

        let mut insert = writer
            .prepare("INSERT INTO blobs (data) VALUES (?);")
            .unwrap();
        insert.bind_blob(1, blob).unwrap();
        let insert_outcome = insert.step().unwrap();
        assert_eq!(insert_outcome, StepResult::Done);

        // Read the blob back through a second connection to the same database.
        let reader = Connection::open_memory("blob_round_trips");
        let mut select = reader.prepare("SELECT * FROM blobs;").unwrap();
        assert_eq!(select.step().unwrap(), StepResult::Row);
        assert_eq!(select.column_blob(0).unwrap(), blob);
        assert_eq!(select.step().unwrap(), StepResult::Done);

        // Delete the blob through the second connection and verify it's gone
        // when queried through the first one.
        reader.exec("DELETE FROM blobs;").unwrap();
        let mut select = writer.prepare("SELECT * FROM blobs;").unwrap();
        assert_eq!(select.step().unwrap(), StepResult::Done);
    }
}