1use std::ffi::{c_int, CStr, CString};
2use std::marker::PhantomData;
3use std::{ptr, slice, str};
4
5use anyhow::{anyhow, Context, Result};
6use libsqlite3_sys::*;
7
8use crate::bindable::{Bind, Column};
9use crate::connection::Connection;
10
11pub struct Statement<'a> {
12 raw_statements: Vec<*mut sqlite3_stmt>,
13 current_statement: usize,
14 connection: &'a Connection,
15 phantom: PhantomData<sqlite3_stmt>,
16}
17
/// Outcome of advancing a [`Statement`] by one step.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum StepResult {
    /// SQLite reported `SQLITE_ROW`: a row is available to read.
    Row,
    /// SQLite reported `SQLITE_DONE` and there is nothing further to run.
    Done,
    /// SQLite reported `SQLITE_MISUSE`.
    Misuse,
    /// Any other result code, carried verbatim.
    Other(i32),
}
25
/// Storage class of a column value, as mapped from `sqlite3_column_type`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum SqlType {
    /// Mapped from `SQLITE_TEXT`.
    Text,
    /// Mapped from `SQLITE_INTEGER`.
    Integer,
    /// Mapped from `SQLITE_BLOB`.
    Blob,
    /// Mapped from `SQLITE_FLOAT`.
    Float,
    /// Mapped from `SQLITE_NULL`.
    Null,
}
34
35impl<'a> Statement<'a> {
36 pub fn prepare<T: AsRef<str>>(connection: &'a Connection, query: T) -> Result<Self> {
37 let mut statement = Self {
38 raw_statements: Default::default(),
39 current_statement: 0,
40 connection,
41 phantom: PhantomData,
42 };
43
44 unsafe {
45 let sql = CString::new(query.as_ref())?;
46 let mut remaining_sql = sql.as_c_str();
47 while {
48 let remaining_sql_str = remaining_sql.to_str()?.trim();
49 remaining_sql_str != ";" && !remaining_sql_str.is_empty()
50 } {
51 let mut raw_statement = 0 as *mut sqlite3_stmt;
52 let mut remaining_sql_ptr = ptr::null();
53 sqlite3_prepare_v2(
54 connection.sqlite3,
55 remaining_sql.as_ptr(),
56 -1,
57 &mut raw_statement,
58 &mut remaining_sql_ptr,
59 );
60 remaining_sql = CStr::from_ptr(remaining_sql_ptr);
61 statement.raw_statements.push(raw_statement);
62
63 connection.last_error().with_context(|| {
64 format!("Prepare call failed for query:\n{}", query.as_ref())
65 })?;
66 }
67 }
68
69 Ok(statement)
70 }
71
72 fn current_statement(&self) -> *mut sqlite3_stmt {
73 *self.raw_statements.get(self.current_statement).unwrap()
74 }
75
76 pub fn reset(&mut self) {
77 unsafe {
78 for raw_statement in self.raw_statements.iter() {
79 sqlite3_reset(*raw_statement);
80 }
81 }
82 self.current_statement = 0;
83 }
84
85 pub fn parameter_count(&self) -> i32 {
86 unsafe {
87 self.raw_statements
88 .iter()
89 .map(|raw_statement| sqlite3_bind_parameter_count(*raw_statement))
90 .max()
91 .unwrap_or(0)
92 }
93 }
94
95 pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
96 let index = index as c_int;
97 let blob_pointer = blob.as_ptr() as *const _;
98 let len = blob.len() as c_int;
99 unsafe {
100 for raw_statement in self.raw_statements.iter() {
101 sqlite3_bind_blob(*raw_statement, index, blob_pointer, len, SQLITE_TRANSIENT());
102 }
103 }
104 self.connection.last_error()
105 }
106
107 pub fn column_blob<'b>(&'b mut self, index: i32) -> Result<&'b [u8]> {
108 let index = index as c_int;
109 let pointer = unsafe { sqlite3_column_blob(self.current_statement(), index) };
110
111 self.connection.last_error()?;
112 if pointer.is_null() {
113 return Ok(&[]);
114 }
115 let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
116 self.connection.last_error()?;
117 unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
118 }
119
120 pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
121 let index = index as c_int;
122
123 unsafe {
124 for raw_statement in self.raw_statements.iter() {
125 sqlite3_bind_double(*raw_statement, index, double);
126 }
127 }
128 self.connection.last_error()
129 }
130
131 pub fn column_double(&self, index: i32) -> Result<f64> {
132 let index = index as c_int;
133 let result = unsafe { sqlite3_column_double(self.current_statement(), index) };
134 self.connection.last_error()?;
135 Ok(result)
136 }
137
138 pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
139 let index = index as c_int;
140
141 unsafe {
142 for raw_statement in self.raw_statements.iter() {
143 sqlite3_bind_int(*raw_statement, index, int);
144 }
145 };
146 self.connection.last_error()
147 }
148
149 pub fn column_int(&self, index: i32) -> Result<i32> {
150 let index = index as c_int;
151 let result = unsafe { sqlite3_column_int(self.current_statement(), index) };
152 self.connection.last_error()?;
153 Ok(result)
154 }
155
156 pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
157 let index = index as c_int;
158 unsafe {
159 for raw_statement in self.raw_statements.iter() {
160 sqlite3_bind_int64(*raw_statement, index, int);
161 }
162 }
163 self.connection.last_error()
164 }
165
166 pub fn column_int64(&self, index: i32) -> Result<i64> {
167 let index = index as c_int;
168 let result = unsafe { sqlite3_column_int64(self.current_statement(), index) };
169 self.connection.last_error()?;
170 Ok(result)
171 }
172
173 pub fn bind_null(&self, index: i32) -> Result<()> {
174 let index = index as c_int;
175 unsafe {
176 for raw_statement in self.raw_statements.iter() {
177 sqlite3_bind_null(*raw_statement, index);
178 }
179 }
180 self.connection.last_error()
181 }
182
183 pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
184 let index = index as c_int;
185 let text_pointer = text.as_ptr() as *const _;
186 let len = text.len() as c_int;
187 unsafe {
188 for raw_statement in self.raw_statements.iter() {
189 sqlite3_bind_text(*raw_statement, index, text_pointer, len, SQLITE_TRANSIENT());
190 }
191 }
192 self.connection.last_error()
193 }
194
195 pub fn column_text<'b>(&'b mut self, index: i32) -> Result<&'b str> {
196 let index = index as c_int;
197 let pointer = unsafe { sqlite3_column_text(self.current_statement(), index) };
198
199 self.connection.last_error()?;
200 if pointer.is_null() {
201 return Ok("");
202 }
203 let len = unsafe { sqlite3_column_bytes(self.current_statement(), index) as usize };
204 self.connection.last_error()?;
205
206 let slice = unsafe { slice::from_raw_parts(pointer as *const u8, len) };
207 Ok(str::from_utf8(slice)?)
208 }
209
210 pub fn bind<T: Bind>(&self, value: T, index: i32) -> Result<i32> {
211 debug_assert!(index > 0);
212 value.bind(self, index)
213 }
214
215 pub fn column<T: Column>(&mut self) -> Result<T> {
216 let (result, _) = T::column(self, 0)?;
217 Ok(result)
218 }
219
220 pub fn column_type(&mut self, index: i32) -> Result<SqlType> {
221 let result = unsafe { sqlite3_column_type(self.current_statement(), index) };
222 self.connection.last_error()?;
223 match result {
224 SQLITE_INTEGER => Ok(SqlType::Integer),
225 SQLITE_FLOAT => Ok(SqlType::Float),
226 SQLITE_TEXT => Ok(SqlType::Text),
227 SQLITE_BLOB => Ok(SqlType::Blob),
228 SQLITE_NULL => Ok(SqlType::Null),
229 _ => Err(anyhow!("Column type returned was incorrect ")),
230 }
231 }
232
233 pub fn with_bindings(&mut self, bindings: impl Bind) -> Result<&mut Self> {
234 self.bind(bindings, 1)?;
235 Ok(self)
236 }
237
238 fn step(&mut self) -> Result<StepResult> {
239 unsafe {
240 match sqlite3_step(self.current_statement()) {
241 SQLITE_ROW => Ok(StepResult::Row),
242 SQLITE_DONE => {
243 if self.current_statement >= self.raw_statements.len() - 1 {
244 Ok(StepResult::Done)
245 } else {
246 self.current_statement += 1;
247 self.step()
248 }
249 }
250 SQLITE_MISUSE => Ok(StepResult::Misuse),
251 other => self
252 .connection
253 .last_error()
254 .map(|_| StepResult::Other(other)),
255 }
256 }
257 }
258
259 pub fn exec(&mut self) -> Result<()> {
260 fn logic(this: &mut Statement) -> Result<()> {
261 while this.step()? == StepResult::Row {}
262 Ok(())
263 }
264 let result = logic(self);
265 self.reset();
266 result
267 }
268
269 pub fn map<R>(&mut self, callback: impl FnMut(&mut Statement) -> Result<R>) -> Result<Vec<R>> {
270 fn logic<R>(
271 this: &mut Statement,
272 mut callback: impl FnMut(&mut Statement) -> Result<R>,
273 ) -> Result<Vec<R>> {
274 let mut mapped_rows = Vec::new();
275 while this.step()? == StepResult::Row {
276 mapped_rows.push(callback(this)?);
277 }
278 Ok(mapped_rows)
279 }
280
281 let result = logic(self, callback);
282 self.reset();
283 result
284 }
285
286 pub fn rows<R: Column>(&mut self) -> Result<Vec<R>> {
287 self.map(|s| s.column::<R>())
288 }
289
290 pub fn single<R>(&mut self, callback: impl FnOnce(&mut Statement) -> Result<R>) -> Result<R> {
291 fn logic<R>(
292 this: &mut Statement,
293 callback: impl FnOnce(&mut Statement) -> Result<R>,
294 ) -> Result<R> {
295 if this.step()? != StepResult::Row {
296 return Err(anyhow!(
297 "Single(Map) called with query that returns no rows."
298 ));
299 }
300 callback(this)
301 }
302 let result = logic(self, callback);
303 self.reset();
304 result
305 }
306
307 pub fn row<R: Column>(&mut self) -> Result<R> {
308 self.single(|this| this.column::<R>())
309 }
310
311 pub fn maybe<R>(
312 &mut self,
313 callback: impl FnOnce(&mut Statement) -> Result<R>,
314 ) -> Result<Option<R>> {
315 fn logic<R>(
316 this: &mut Statement,
317 callback: impl FnOnce(&mut Statement) -> Result<R>,
318 ) -> Result<Option<R>> {
319 if this.step()? != StepResult::Row {
320 return Ok(None);
321 }
322 callback(this).map(|r| Some(r))
323 }
324 let result = logic(self, callback);
325 self.reset();
326 result
327 }
328
329 pub fn maybe_row<R: Column>(&mut self) -> Result<Option<R>> {
330 self.maybe(|this| this.column::<R>())
331 }
332}
333
334impl<'a> Drop for Statement<'a> {
335 fn drop(&mut self) {
336 unsafe {
337 for raw_statement in self.raw_statements.iter() {
338 sqlite3_finalize(*raw_statement);
339 }
340 }
341 }
342}
343
#[cfg(test)]
mod test {
    use indoc::indoc;

    use crate::{
        connection::Connection,
        statement::{Statement, StepResult},
    };

    /// Writes a blob through one connection, reads it back through a second
    /// connection opened under the same in-memory name, then deletes it and
    /// checks the first connection no longer sees it.
    #[test]
    fn blob_round_trips() {
        let writer = Connection::open_memory(Some("blob_round_trips"));
        writer
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                )"})
            .unwrap()()
        .unwrap();

        let payload = &[0, 1, 2, 4, 8, 16, 32, 64];

        let mut insert =
            Statement::prepare(&writer, "INSERT INTO blobs (data) VALUES (?)").unwrap();
        insert.bind_blob(1, payload).unwrap();
        assert_eq!(insert.step().unwrap(), StepResult::Done);

        // Read the blob back through a second connection.
        let reader = Connection::open_memory(Some("blob_round_trips"));
        let mut select = Statement::prepare(&reader, "SELECT * FROM blobs").unwrap();
        assert_eq!(select.step().unwrap(), StepResult::Row);
        assert_eq!(select.column_blob(0).unwrap(), payload);
        assert_eq!(select.step().unwrap(), StepResult::Done);

        // Delete the added blob and verify it's deleted on the other side.
        reader.exec("DELETE FROM blobs").unwrap()().unwrap();
        let mut select_again = Statement::prepare(&writer, "SELECT * FROM blobs").unwrap();
        assert_eq!(select_again.step().unwrap(), StepResult::Done);
    }

    /// `select_row` yields `None` for an empty table and `Some(value)` once a
    /// row has been inserted.
    #[test]
    pub fn maybe_returns_options() {
        let db = Connection::open_memory(Some("maybe_returns_options"));
        db.exec(indoc! {"
            CREATE TABLE texts (
                text TEXT
            )"})
        .unwrap()()
        .unwrap();

        assert!(db
            .select_row::<String>("SELECT text FROM texts")
            .unwrap()()
        .unwrap()
        .is_none());

        let inserted_text = "This is a test";

        db.exec_bound("INSERT INTO texts VALUES (?)").unwrap()(inserted_text).unwrap();

        assert_eq!(
            db.select_row("SELECT text FROM texts").unwrap()().unwrap(),
            Some(inserted_text.to_string())
        );
    }
}