statement.rs

  1use std::ffi::{c_int, CString};
  2use std::marker::PhantomData;
  3use std::{slice, str};
  4
  5use anyhow::{anyhow, Context, Result};
  6use libsqlite3_sys::*;
  7
  8use crate::bindable::{Bind, Column};
  9use crate::connection::Connection;
 10
 11pub struct Statement<'a> {
 12    raw_statement: *mut sqlite3_stmt,
 13    connection: &'a Connection,
 14    phantom: PhantomData<sqlite3_stmt>,
 15}
 16
 17#[derive(Clone, Copy, PartialEq, Eq, Debug)]
 18pub enum StepResult {
 19    Row,
 20    Done,
 21    Misuse,
 22    Other(i32),
 23}
 24
 25#[derive(Clone, Copy, PartialEq, Eq, Debug)]
 26pub enum SqlType {
 27    Text,
 28    Integer,
 29    Blob,
 30    Float,
 31    Null,
 32}
 33
 34impl<'a> Statement<'a> {
 35    pub fn prepare<T: AsRef<str>>(connection: &'a Connection, query: T) -> Result<Self> {
 36        let mut statement = Self {
 37            raw_statement: 0 as *mut _,
 38            connection,
 39            phantom: PhantomData,
 40        };
 41
 42        unsafe {
 43            sqlite3_prepare_v2(
 44                connection.sqlite3,
 45                CString::new(query.as_ref())?.as_ptr(),
 46                -1,
 47                &mut statement.raw_statement,
 48                0 as *mut _,
 49            );
 50
 51            connection.last_error().context("Prepare call failed.")?;
 52        }
 53
 54        Ok(statement)
 55    }
 56
 57    pub fn reset(&mut self) {
 58        unsafe {
 59            sqlite3_reset(self.raw_statement);
 60        }
 61    }
 62
 63    pub fn parameter_count(&self) -> i32 {
 64        unsafe { sqlite3_bind_parameter_count(self.raw_statement) }
 65    }
 66
 67    pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
 68        let index = index as c_int;
 69        let blob_pointer = blob.as_ptr() as *const _;
 70        let len = blob.len() as c_int;
 71        unsafe {
 72            sqlite3_bind_blob(
 73                self.raw_statement,
 74                index,
 75                blob_pointer,
 76                len,
 77                SQLITE_TRANSIENT(),
 78            );
 79        }
 80        self.connection.last_error()
 81    }
 82
 83    pub fn column_blob<'b>(&'b mut self, index: i32) -> Result<&'b [u8]> {
 84        let index = index as c_int;
 85        let pointer = unsafe { sqlite3_column_blob(self.raw_statement, index) };
 86
 87        self.connection.last_error()?;
 88        if pointer.is_null() {
 89            return Ok(&[]);
 90        }
 91        let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize };
 92        self.connection.last_error()?;
 93        unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
 94    }
 95
 96    pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
 97        let index = index as c_int;
 98
 99        unsafe {
100            sqlite3_bind_double(self.raw_statement, index, double);
101        }
102        self.connection.last_error()
103    }
104
105    pub fn column_double(&self, index: i32) -> Result<f64> {
106        let index = index as c_int;
107        let result = unsafe { sqlite3_column_double(self.raw_statement, index) };
108        self.connection.last_error()?;
109        Ok(result)
110    }
111
112    pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
113        let index = index as c_int;
114
115        unsafe {
116            sqlite3_bind_int(self.raw_statement, index, int);
117        }
118        self.connection.last_error()
119    }
120
121    pub fn column_int(&self, index: i32) -> Result<i32> {
122        let index = index as c_int;
123        let result = unsafe { sqlite3_column_int(self.raw_statement, index) };
124        self.connection.last_error()?;
125        Ok(result)
126    }
127
128    pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
129        let index = index as c_int;
130        unsafe {
131            sqlite3_bind_int64(self.raw_statement, index, int);
132        }
133        self.connection.last_error()
134    }
135
136    pub fn column_int64(&self, index: i32) -> Result<i64> {
137        let index = index as c_int;
138        let result = unsafe { sqlite3_column_int64(self.raw_statement, index) };
139        self.connection.last_error()?;
140        Ok(result)
141    }
142
143    pub fn bind_null(&self, index: i32) -> Result<()> {
144        let index = index as c_int;
145        unsafe {
146            sqlite3_bind_null(self.raw_statement, index);
147        }
148        self.connection.last_error()
149    }
150
151    pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
152        let index = index as c_int;
153        let text_pointer = text.as_ptr() as *const _;
154        let len = text.len() as c_int;
155        unsafe {
156            sqlite3_bind_blob(
157                self.raw_statement,
158                index,
159                text_pointer,
160                len,
161                SQLITE_TRANSIENT(),
162            );
163        }
164        self.connection.last_error()
165    }
166
167    pub fn column_text<'b>(&'b mut self, index: i32) -> Result<&'b str> {
168        let index = index as c_int;
169        let pointer = unsafe { sqlite3_column_text(self.raw_statement, index) };
170
171        self.connection.last_error()?;
172        if pointer.is_null() {
173            return Ok("");
174        }
175        let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize };
176        self.connection.last_error()?;
177
178        let slice = unsafe { slice::from_raw_parts(pointer as *const u8, len) };
179        Ok(str::from_utf8(slice)?)
180    }
181
182    pub fn bind_value<T: Bind>(&self, value: T, idx: i32) -> Result<()> {
183        debug_assert!(idx > 0);
184        value.bind(self, idx)?;
185        Ok(())
186    }
187
188    pub fn column<T: Column>(&mut self) -> Result<T> {
189        let (result, _) = T::column(self, 0)?;
190        Ok(result)
191    }
192
193    pub fn column_type(&mut self, index: i32) -> Result<SqlType> {
194        let result = unsafe { sqlite3_column_type(self.raw_statement, index) }; // SELECT <FRIEND> FROM TABLE
195        self.connection.last_error()?;
196        match result {
197            SQLITE_INTEGER => Ok(SqlType::Integer),
198            SQLITE_FLOAT => Ok(SqlType::Float),
199            SQLITE_TEXT => Ok(SqlType::Text),
200            SQLITE_BLOB => Ok(SqlType::Blob),
201            SQLITE_NULL => Ok(SqlType::Null),
202            _ => Err(anyhow!("Column type returned was incorrect ")),
203        }
204    }
205
206    pub fn bind(&mut self, bindings: impl Bind) -> Result<&mut Self> {
207        self.bind_value(bindings, 1)?;
208        Ok(self)
209    }
210
211    fn step(&mut self) -> Result<StepResult> {
212        unsafe {
213            match sqlite3_step(self.raw_statement) {
214                SQLITE_ROW => Ok(StepResult::Row),
215                SQLITE_DONE => Ok(StepResult::Done),
216                SQLITE_MISUSE => Ok(StepResult::Misuse),
217                other => self
218                    .connection
219                    .last_error()
220                    .map(|_| StepResult::Other(other)),
221            }
222        }
223    }
224
225    pub fn insert(&mut self) -> Result<i64> {
226        self.exec()?;
227        Ok(self.connection.last_insert_id())
228    }
229
230    pub fn exec(&mut self) -> Result<()> {
231        fn logic(this: &mut Statement) -> Result<()> {
232            while this.step()? == StepResult::Row {}
233            Ok(())
234        }
235        let result = logic(self);
236        self.reset();
237        result
238    }
239
240    pub fn map<R>(&mut self, callback: impl FnMut(&mut Statement) -> Result<R>) -> Result<Vec<R>> {
241        fn logic<R>(
242            this: &mut Statement,
243            mut callback: impl FnMut(&mut Statement) -> Result<R>,
244        ) -> Result<Vec<R>> {
245            let mut mapped_rows = Vec::new();
246            while this.step()? == StepResult::Row {
247                mapped_rows.push(callback(this)?);
248            }
249            Ok(mapped_rows)
250        }
251
252        let result = logic(self, callback);
253        self.reset();
254        result
255    }
256
257    pub fn rows<R: Column>(&mut self) -> Result<Vec<R>> {
258        self.map(|s| s.column::<R>())
259    }
260
261    pub fn single<R>(&mut self, callback: impl FnOnce(&mut Statement) -> Result<R>) -> Result<R> {
262        fn logic<R>(
263            this: &mut Statement,
264            callback: impl FnOnce(&mut Statement) -> Result<R>,
265        ) -> Result<R> {
266            if this.step()? != StepResult::Row {
267                return Err(anyhow!(
268                    "Single(Map) called with query that returns no rows."
269                ));
270            }
271            callback(this)
272        }
273        let result = logic(self, callback);
274        self.reset();
275        result
276    }
277
278    pub fn row<R: Column>(&mut self) -> Result<R> {
279        self.single(|this| this.column::<R>())
280    }
281
282    pub fn maybe<R>(
283        &mut self,
284        callback: impl FnOnce(&mut Statement) -> Result<R>,
285    ) -> Result<Option<R>> {
286        fn logic<R>(
287            this: &mut Statement,
288            callback: impl FnOnce(&mut Statement) -> Result<R>,
289        ) -> Result<Option<R>> {
290            if this.step()? != StepResult::Row {
291                return Ok(None);
292            }
293            callback(this).map(|r| Some(r))
294        }
295        let result = logic(self, callback);
296        self.reset();
297        result
298    }
299
300    pub fn maybe_row<R: Column>(&mut self) -> Result<Option<R>> {
301        self.maybe(|this| this.column::<R>())
302    }
303}
304
305impl<'a> Drop for Statement<'a> {
306    fn drop(&mut self) {
307        unsafe {
308            sqlite3_finalize(self.raw_statement);
309            self.connection
310                .last_error()
311                .expect("sqlite3 finalize failed for statement :(");
312        };
313    }
314}
315
#[cfg(test)]
mod test {
    use indoc::indoc;

    use crate::{connection::Connection, statement::StepResult};

    /// Round-trips a blob between two connections that open the same named in-memory
    /// database ("blob_round_trips").
    ///
    /// One connection writes the blob and a second reads it back byte-for-byte. The
    /// second connection then deletes the row, and the first must see an empty table.
    #[test]
    fn blob_round_trips() {
        let writer = Connection::open_memory("blob_round_trips");
        writer
            .exec(indoc! {"
            CREATE TABLE blobs (
            data BLOB
            );"})
            .unwrap();

        let payload = &[0, 1, 2, 4, 8, 16, 32, 64];

        // Insert through a bound parameter; a single step should complete the INSERT.
        let mut insert = writer
            .prepare("INSERT INTO blobs (data) VALUES (?);")
            .unwrap();
        insert.bind_blob(1, payload).unwrap();
        assert_eq!(insert.step().unwrap(), StepResult::Done);

        // Read the blob back through a second connection to the same database.
        let reader = Connection::open_memory("blob_round_trips");
        let mut select_all = reader.prepare("SELECT * FROM blobs;").unwrap();
        assert_eq!(select_all.step().unwrap(), StepResult::Row);
        assert_eq!(select_all.column_blob(0).unwrap(), payload);
        assert_eq!(select_all.step().unwrap(), StepResult::Done);

        // Delete the row via the reader and confirm the writer sees it is gone.
        reader.exec("DELETE FROM blobs;").unwrap();
        let mut recheck = writer.prepare("SELECT * FROM blobs;").unwrap();
        assert_eq!(recheck.step().unwrap(), StepResult::Done);
    }
}