sqlez_macros.rs

use proc_macro::{Delimiter, Span, TokenStream, TokenTree};
use syn::Error;

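// In-memory SQLite connection used to validate query syntax at compile time.
// The check is skipped entirely on Linux and FreeBSD (see the matching cfg
// in `sql` below).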
#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
static SQLITE: std::sync::LazyLock<sqlez::thread_safe_connection::ThreadSafeConnection> =
    std::sync::LazyLock::new(|| {
        sqlez::thread_safe_connection::ThreadSafeConnection::new(
            ":memory:",
            false,
            None,
            Some(sqlez::thread_safe_connection::locking_queue()),
        )
    });

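/// Emits the given SQL as a formatted raw string literal, after checking its
/// syntax against an in-memory SQLite connection on platforms where the
/// check is enabled. An illustrative (hypothetical) call site:
///
/// ```ignore
/// let query: &str = sql!(SELECT id FROM users WHERE name = ?);
/// ```
///
/// A syntax error in the query becomes a compile error pointing at the
/// offending token.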
#[proc_macro]
pub fn sql(tokens: TokenStream) -> TokenStream {
    let (spans, sql) = make_sql(tokens);

    #[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
    let error = SQLITE.sql_has_syntax_error(sql.trim());

    #[cfg(any(target_os = "linux", target_os = "freebsd"))]
    let error: Option<(String, usize)> = None;

    let formatted_sql = sqlformat::format(&sql, &sqlformat::QueryParams::None, Default::default());

    if let Some((error, error_offset)) = error {
        create_error(spans, error_offset, error, &formatted_sql)
    } else {
        format!("r#\"{}\"#", formatted_sql).parse().unwrap()
    }
}

fn create_error(
    spans: Vec<(usize, Span)>,
    error_offset: usize,
    error: String,
    formatted_sql: &str,
) -> TokenStream {
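    // Find the first token whose end offset lies past the error's byte
    // offset; that token is our best guess at where the error occurred.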
    let error_span = spans
        .into_iter()
        .find(|(offset, _)| *offset > error_offset)
        .map(|(_, span)| span)
        .unwrap_or_else(Span::call_site);
    let error_text = format!("Sql Error: {}\nFor Query: {}", error, formatted_sql);
    TokenStream::from(Error::new(error_span.into(), error_text).into_compile_error())
}

fn make_sql(tokens: TokenStream) -> (Vec<(usize, Span)>, String) {
    let mut sql_tokens = vec![];
    flatten_stream(tokens, &mut sql_tokens);
    // Lookup table of spans, keyed by the byte offset just past each token's text
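    // e.g. the tokens `DELETE FROM users` flatten to "DELETE ", "FROM " and
    // "users ", so sql becomes "DELETE FROM users " and spans records the
    // end offsets 7, 12 and 18 (one entry per token).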
    let mut spans: Vec<(usize, Span)> = Vec::new();
    let mut sql = String::new();
    for (token_text, span) in sql_tokens {
        sql.push_str(&token_text);
        spans.push((sql.len(), span));
    }
    (spans, sql)
}

/// Normalizes the representation of groups so that delimiters are always
/// padded with spaces, which is why we don't use the usual `.to_string()`.
/// The padding lets the span lookup in `create_error` resolve the
/// ambiguity of '(tokens)' vs. '( token )', since SQLite reports errors
/// as byte offsets.
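///
/// For example (illustrative), the group `(1)` flattens to the pieces
/// `"( "`, `"1"` and `" ) "`, which concatenate to `"( 1 ) "`.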
fn flatten_stream(tokens: TokenStream, result: &mut Vec<(String, Span)>) {
    for token_tree in tokens {
        match token_tree {
            TokenTree::Group(group) => {
                // push open delimiter
                result.push((open_delimiter(group.delimiter()), group.span()));
                // recurse
                flatten_stream(group.stream(), result);
                // push close delimiter
                result.push((close_delimiter(group.delimiter()), group.span()));
            }
            TokenTree::Ident(ident) => {
                result.push((format!("{} ", ident), ident.span()));
            }
            leaf_tree => result.push((leaf_tree.to_string(), leaf_tree.span())),
        }
    }
}

fn open_delimiter(delimiter: Delimiter) -> String {
    match delimiter {
        Delimiter::Parenthesis => "( ".to_string(),
        Delimiter::Brace => "{ ".to_string(),
        Delimiter::Bracket => "[ ".to_string(),
        Delimiter::None => "".to_string(),
    }
}

fn close_delimiter(delimiter: Delimiter) -> String {
    match delimiter {
        Delimiter::Parenthesis => " ) ".to_string(),
        Delimiter::Brace => " } ".to_string(),
        Delimiter::Bracket => " ] ".to_string(),
        Delimiter::None => "".to_string(),
    }
}
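
// Sketch of the end-to-end flow (illustrative, not part of this file):
// `sql!(SELEC 1)` fails SQLITE's syntax check, so the macro expands to a
// compile error at the `SELEC` token, while `sql!(SELECT 1)` expands to a
// raw string literal containing sqlformat's formatting of "SELECT 1".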