sqlez_macros.rs

 1use proc_macro::{Delimiter, Span, TokenStream, TokenTree};
 2use sqlez::thread_safe_connection::{locking_queue, ThreadSafeConnection};
 3use syn::Error;
 4
 5lazy_static::lazy_static! {
 6    static ref SQLITE: ThreadSafeConnection =  {
 7        ThreadSafeConnection::new(":memory:", false, None, Some(locking_queue()))
 8    };
 9}
10
11#[proc_macro]
12pub fn sql(tokens: TokenStream) -> TokenStream {
13    let (spans, sql) = make_sql(tokens);
14
15    #[cfg(not(target_os = "linux"))]
16    let error = SQLITE.sql_has_syntax_error(sql.trim());
17
18    #[cfg(target_os = "linux")]
19    let error: Option<(String, usize)> = None;
20
21    let formatted_sql = sqlformat::format(&sql, &sqlformat::QueryParams::None, Default::default());
22
23    if let Some((error, error_offset)) = error {
24        create_error(spans, error_offset, error, &formatted_sql)
25    } else {
26        format!("r#\"{}\"#", &formatted_sql).parse().unwrap()
27    }
28}
29
30fn create_error(
31    spans: Vec<(usize, Span)>,
32    error_offset: usize,
33    error: String,
34    formatted_sql: &String,
35) -> TokenStream {
36    let error_span = spans
37        .into_iter()
38        .skip_while(|(offset, _)| offset <= &error_offset)
39        .map(|(_, span)| span)
40        .next()
41        .unwrap_or_else(Span::call_site);
42    let error_text = format!("Sql Error: {}\nFor Query: {}", error, formatted_sql);
43    TokenStream::from(Error::new(error_span.into(), error_text).into_compile_error())
44}
45
46fn make_sql(tokens: TokenStream) -> (Vec<(usize, Span)>, String) {
47    let mut sql_tokens = vec![];
48    flatten_stream(tokens, &mut sql_tokens);
49    // Lookup of spans by offset at the end of the token
50    let mut spans: Vec<(usize, Span)> = Vec::new();
51    let mut sql = String::new();
52    for (token_text, span) in sql_tokens {
53        sql.push_str(&token_text);
54        spans.push((sql.len(), span));
55    }
56    (spans, sql)
57}
58
59/// This method exists to normalize the representation of groups
60/// to always include spaces between tokens. This is why we don't use the usual .to_string().
61/// This allows our token search in token_at_offset to resolve
62/// ambiguity of '(tokens)' vs. '( token )', due to sqlite requiring byte offsets
63fn flatten_stream(tokens: TokenStream, result: &mut Vec<(String, Span)>) {
64    for token_tree in tokens.into_iter() {
65        match token_tree {
66            TokenTree::Group(group) => {
67                // push open delimiter
68                result.push((open_delimiter(group.delimiter()), group.span()));
69                // recurse
70                flatten_stream(group.stream(), result);
71                // push close delimiter
72                result.push((close_delimiter(group.delimiter()), group.span()));
73            }
74            TokenTree::Ident(ident) => {
75                result.push((format!("{} ", ident), ident.span()));
76            }
77            leaf_tree => result.push((leaf_tree.to_string(), leaf_tree.span())),
78        }
79    }
80}
81
/// Returns the SQL text that opens a group with the given delimiter, followed by a space.
///
/// `Delimiter::Brace` is `{ ... }` and `Delimiter::Bracket` is `[ ... ]`. Keeping brackets
/// as brackets matters because SQLite accepts `[name]` as a quoted identifier.
/// `Delimiter::None` groups are invisible in the source and contribute no text.
fn open_delimiter(delimiter: Delimiter) -> String {
    match delimiter {
        Delimiter::Parenthesis => "( ",
        Delimiter::Brace => "{ ",
        Delimiter::Bracket => "[ ",
        Delimiter::None => "",
    }
    .to_string()
}

/// Returns the SQL text that closes a group with the given delimiter, padded with spaces.
///
/// Mirrors `open_delimiter`: `Brace` closes with `}`, `Bracket` with `]`, and
/// `Delimiter::None` contributes no text.
fn close_delimiter(delimiter: Delimiter) -> String {
    match delimiter {
        Delimiter::Parenthesis => " ) ",
        Delimiter::Brace => " } ",
        Delimiter::Bracket => " ] ",
        Delimiter::None => "",
    }
    .to_string()
}