1use std::ffi::{c_int, CString};
2use std::marker::PhantomData;
3use std::{slice, str};
4
5use anyhow::{anyhow, Context, Result};
6use libsqlite3_sys::*;
7
8use crate::bindable::{Bind, Column};
9use crate::connection::Connection;
10
11pub struct Statement<'a> {
12 raw_statement: *mut sqlite3_stmt,
13 connection: &'a Connection,
14 phantom: PhantomData<sqlite3_stmt>,
15}
16
/// Outcome of a single `sqlite3_step` call on a [`Statement`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StepResult {
    /// `SQLITE_ROW`: a result row is available for reading.
    Row,
    /// `SQLITE_DONE`: the statement ran to completion.
    Done,
    /// `SQLITE_MISUSE`: the statement was stepped in an invalid state (e.g. a null handle).
    Misuse,
    /// Any other result code for which the connection reported no error.
    Other(i32),
}
24
/// Storage class of a result column, as reported by `sqlite3_column_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlType {
    /// `SQLITE_TEXT`
    Text,
    /// `SQLITE_INTEGER`
    Integer,
    /// `SQLITE_BLOB`
    Blob,
    /// `SQLITE_FLOAT`
    Float,
    /// `SQLITE_NULL`
    Null,
}
33
34impl<'a> Statement<'a> {
35 pub fn prepare<T: AsRef<str>>(connection: &'a Connection, query: T) -> Result<Self> {
36 let mut statement = Self {
37 raw_statement: 0 as *mut _,
38 connection,
39 phantom: PhantomData,
40 };
41
42 unsafe {
43 sqlite3_prepare_v2(
44 connection.sqlite3,
45 CString::new(query.as_ref())?.as_ptr(),
46 -1,
47 &mut statement.raw_statement,
48 0 as *mut _,
49 );
50
51 connection
52 .last_error()
53 .with_context(|| format!("Prepare call failed for query:\n{}", query.as_ref()))?;
54 }
55
56 Ok(statement)
57 }
58
59 pub fn reset(&mut self) {
60 unsafe {
61 sqlite3_reset(self.raw_statement);
62 }
63 }
64
65 pub fn parameter_count(&self) -> i32 {
66 unsafe { sqlite3_bind_parameter_count(self.raw_statement) }
67 }
68
69 pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> {
70 // dbg!("bind blob", index);
71 let index = index as c_int;
72 let blob_pointer = blob.as_ptr() as *const _;
73 let len = blob.len() as c_int;
74 unsafe {
75 sqlite3_bind_blob(
76 self.raw_statement,
77 index,
78 blob_pointer,
79 len,
80 SQLITE_TRANSIENT(),
81 );
82 }
83 self.connection.last_error()
84 }
85
86 pub fn column_blob<'b>(&'b mut self, index: i32) -> Result<&'b [u8]> {
87 let index = index as c_int;
88 let pointer = unsafe { sqlite3_column_blob(self.raw_statement, index) };
89
90 self.connection.last_error()?;
91 if pointer.is_null() {
92 return Ok(&[]);
93 }
94 let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize };
95 self.connection.last_error()?;
96 unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) }
97 }
98
99 pub fn bind_double(&self, index: i32, double: f64) -> Result<()> {
100 // dbg!("bind double", index);
101 let index = index as c_int;
102
103 unsafe {
104 sqlite3_bind_double(self.raw_statement, index, double);
105 }
106 self.connection.last_error()
107 }
108
109 pub fn column_double(&self, index: i32) -> Result<f64> {
110 let index = index as c_int;
111 let result = unsafe { sqlite3_column_double(self.raw_statement, index) };
112 self.connection.last_error()?;
113 Ok(result)
114 }
115
116 pub fn bind_int(&self, index: i32, int: i32) -> Result<()> {
117 // dbg!("bind int", index);
118 let index = index as c_int;
119
120 unsafe {
121 sqlite3_bind_int(self.raw_statement, index, int);
122 };
123 self.connection.last_error()
124 }
125
126 pub fn column_int(&self, index: i32) -> Result<i32> {
127 let index = index as c_int;
128 let result = unsafe { sqlite3_column_int(self.raw_statement, index) };
129 self.connection.last_error()?;
130 Ok(result)
131 }
132
133 pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> {
134 // dbg!("bind int64", index);
135 let index = index as c_int;
136 unsafe {
137 sqlite3_bind_int64(self.raw_statement, index, int);
138 }
139 self.connection.last_error()
140 }
141
142 pub fn column_int64(&self, index: i32) -> Result<i64> {
143 let index = index as c_int;
144 let result = unsafe { sqlite3_column_int64(self.raw_statement, index) };
145 self.connection.last_error()?;
146 Ok(result)
147 }
148
149 pub fn bind_null(&self, index: i32) -> Result<()> {
150 // dbg!("bind null", index);
151 let index = index as c_int;
152 unsafe {
153 sqlite3_bind_null(self.raw_statement, index);
154 }
155 self.connection.last_error()
156 }
157
158 pub fn bind_text(&self, index: i32, text: &str) -> Result<()> {
159 // dbg!("bind text", index, text);
160 let index = index as c_int;
161 let text_pointer = text.as_ptr() as *const _;
162 let len = text.len() as c_int;
163 unsafe {
164 sqlite3_bind_text(
165 self.raw_statement,
166 index,
167 text_pointer,
168 len,
169 SQLITE_TRANSIENT(),
170 );
171 }
172 self.connection.last_error()
173 }
174
175 pub fn column_text<'b>(&'b mut self, index: i32) -> Result<&'b str> {
176 let index = index as c_int;
177 let pointer = unsafe { sqlite3_column_text(self.raw_statement, index) };
178
179 self.connection.last_error()?;
180 if pointer.is_null() {
181 return Ok("");
182 }
183 let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize };
184 self.connection.last_error()?;
185
186 let slice = unsafe { slice::from_raw_parts(pointer as *const u8, len) };
187 Ok(str::from_utf8(slice)?)
188 }
189
190 pub fn bind<T: Bind>(&self, value: T, index: i32) -> Result<i32> {
191 debug_assert!(index > 0);
192 value.bind(self, index)
193 }
194
195 pub fn column<T: Column>(&mut self) -> Result<T> {
196 let (result, _) = T::column(self, 0)?;
197 Ok(result)
198 }
199
200 pub fn column_type(&mut self, index: i32) -> Result<SqlType> {
201 let result = unsafe { sqlite3_column_type(self.raw_statement, index) }; // SELECT <FRIEND> FROM TABLE
202 self.connection.last_error()?;
203 match result {
204 SQLITE_INTEGER => Ok(SqlType::Integer),
205 SQLITE_FLOAT => Ok(SqlType::Float),
206 SQLITE_TEXT => Ok(SqlType::Text),
207 SQLITE_BLOB => Ok(SqlType::Blob),
208 SQLITE_NULL => Ok(SqlType::Null),
209 _ => Err(anyhow!("Column type returned was incorrect ")),
210 }
211 }
212
213 pub fn with_bindings(&mut self, bindings: impl Bind) -> Result<&mut Self> {
214 self.bind(bindings, 1)?;
215 Ok(self)
216 }
217
218 fn step(&mut self) -> Result<StepResult> {
219 unsafe {
220 match sqlite3_step(self.raw_statement) {
221 SQLITE_ROW => Ok(StepResult::Row),
222 SQLITE_DONE => Ok(StepResult::Done),
223 SQLITE_MISUSE => Ok(StepResult::Misuse),
224 other => self
225 .connection
226 .last_error()
227 .map(|_| StepResult::Other(other)),
228 }
229 }
230 }
231
232 pub fn insert(&mut self) -> Result<i64> {
233 self.exec()?;
234 Ok(self.connection.last_insert_id())
235 }
236
237 pub fn exec(&mut self) -> Result<()> {
238 fn logic(this: &mut Statement) -> Result<()> {
239 while this.step()? == StepResult::Row {}
240 Ok(())
241 }
242 let result = logic(self);
243 self.reset();
244 result
245 }
246
247 pub fn map<R>(&mut self, callback: impl FnMut(&mut Statement) -> Result<R>) -> Result<Vec<R>> {
248 fn logic<R>(
249 this: &mut Statement,
250 mut callback: impl FnMut(&mut Statement) -> Result<R>,
251 ) -> Result<Vec<R>> {
252 let mut mapped_rows = Vec::new();
253 while this.step()? == StepResult::Row {
254 mapped_rows.push(callback(this)?);
255 }
256 Ok(mapped_rows)
257 }
258
259 let result = logic(self, callback);
260 self.reset();
261 result
262 }
263
264 pub fn rows<R: Column>(&mut self) -> Result<Vec<R>> {
265 self.map(|s| s.column::<R>())
266 }
267
268 pub fn single<R>(&mut self, callback: impl FnOnce(&mut Statement) -> Result<R>) -> Result<R> {
269 fn logic<R>(
270 this: &mut Statement,
271 callback: impl FnOnce(&mut Statement) -> Result<R>,
272 ) -> Result<R> {
273 if this.step()? != StepResult::Row {
274 return Err(anyhow!(
275 "Single(Map) called with query that returns no rows."
276 ));
277 }
278 callback(this)
279 }
280 let result = logic(self, callback);
281 self.reset();
282 result
283 }
284
285 pub fn row<R: Column>(&mut self) -> Result<R> {
286 self.single(|this| this.column::<R>())
287 }
288
289 pub fn maybe<R>(
290 &mut self,
291 callback: impl FnOnce(&mut Statement) -> Result<R>,
292 ) -> Result<Option<R>> {
293 fn logic<R>(
294 this: &mut Statement,
295 callback: impl FnOnce(&mut Statement) -> Result<R>,
296 ) -> Result<Option<R>> {
297 if this.step()? != StepResult::Row {
298 return Ok(None);
299 }
300 callback(this).map(|r| Some(r))
301 }
302 let result = logic(self, callback);
303 self.reset();
304 result
305 }
306
307 pub fn maybe_row<R: Column>(&mut self) -> Result<Option<R>> {
308 self.maybe(|this| this.column::<R>())
309 }
310}
311
312impl<'a> Drop for Statement<'a> {
313 fn drop(&mut self) {
314 unsafe { sqlite3_finalize(self.raw_statement) };
315 }
316}
317
#[cfg(test)]
mod test {
    use indoc::indoc;

    use crate::{connection::Connection, statement::StepResult};

    #[test]
    fn blob_round_trips() {
        let writer = Connection::open_memory("blob_round_trips");
        writer
            .exec(indoc! {"
                CREATE TABLE blobs (
                    data BLOB
                );"})
            .unwrap();

        let blob = &[0, 1, 2, 4, 8, 16, 32, 64];

        let mut insert = writer
            .prepare("INSERT INTO blobs (data) VALUES (?);")
            .unwrap();
        insert.bind_blob(1, blob).unwrap();
        assert_eq!(insert.step().unwrap(), StepResult::Done);

        // Read the blob back through a second connection opened with the same name.
        let reader = Connection::open_memory("blob_round_trips");
        let mut select = reader.prepare("SELECT * FROM blobs;").unwrap();
        assert_eq!(select.step().unwrap(), StepResult::Row);
        assert_eq!(select.column_blob(0).unwrap(), blob);
        assert_eq!(select.step().unwrap(), StepResult::Done);

        // Delete the blob through the second connection and verify it's gone for the first.
        reader.exec("DELETE FROM blobs;").unwrap();
        let mut select_again = writer.prepare("SELECT * FROM blobs;").unwrap();
        assert_eq!(select_again.step().unwrap(), StepResult::Done);
    }

    #[test]
    pub fn maybe_returns_options() {
        let connection = Connection::open_memory("maybe_returns_options");
        connection
            .exec(indoc! {"
                CREATE TABLE texts (
                    text TEXT
                );"})
            .unwrap();

        // An empty table yields no row.
        let before_insert = connection
            .prepare("SELECT text FROM texts")
            .unwrap()
            .maybe_row::<String>()
            .unwrap();
        assert!(before_insert.is_none());

        let text_to_insert = "This is a test";

        connection
            .prepare("INSERT INTO texts VALUES (?)")
            .unwrap()
            .with_bindings(text_to_insert)
            .unwrap()
            .exec()
            .unwrap();

        // After the insert the row is returned as `Some`.
        let after_insert = connection
            .prepare("SELECT text FROM texts")
            .unwrap()
            .maybe_row::<String>()
            .unwrap();
        assert_eq!(after_insert, Some(text_to_insert.to_string()));
    }
}