1use proc_macro::{Delimiter, Span, TokenStream, TokenTree};
2use sqlez::thread_safe_connection::{locking_queue, ThreadSafeConnection};
3use syn::Error;
4
5lazy_static::lazy_static! {
6 static ref SQLITE: ThreadSafeConnection = {
7 ThreadSafeConnection::new(":memory:", false, None, Some(locking_queue()))
8 };
9}
10
11#[proc_macro]
12pub fn sql(tokens: TokenStream) -> TokenStream {
13 let (spans, sql) = make_sql(tokens);
14
15 let error = SQLITE.sql_has_syntax_error(sql.trim());
16 let formatted_sql = sqlformat::format(&sql, &sqlformat::QueryParams::None, Default::default());
17
18 if let Some((error, error_offset)) = error {
19 create_error(spans, error_offset, error, &formatted_sql)
20 } else {
21 format!("r#\"{}\"#", &formatted_sql).parse().unwrap()
22 }
23}
24
25fn create_error(
26 spans: Vec<(usize, Span)>,
27 error_offset: usize,
28 error: String,
29 formatted_sql: &String,
30) -> TokenStream {
31 let error_span = spans
32 .into_iter()
33 .skip_while(|(offset, _)| offset <= &error_offset)
34 .map(|(_, span)| span)
35 .next()
36 .unwrap_or_else(Span::call_site);
37 let error_text = format!("Sql Error: {}\nFor Query: {}", error, formatted_sql);
38 TokenStream::from(Error::new(error_span.into(), error_text).into_compile_error())
39}
40
41fn make_sql(tokens: TokenStream) -> (Vec<(usize, Span)>, String) {
42 let mut sql_tokens = vec![];
43 flatten_stream(tokens, &mut sql_tokens);
44 // Lookup of spans by offset at the end of the token
45 let mut spans: Vec<(usize, Span)> = Vec::new();
46 let mut sql = String::new();
47 for (token_text, span) in sql_tokens {
48 sql.push_str(&token_text);
49 spans.push((sql.len(), span));
50 }
51 (spans, sql)
52}
53
54/// This method exists to normalize the representation of groups
55/// to always include spaces between tokens. This is why we don't use the usual .to_string().
56/// This allows our token search in token_at_offset to resolve
57/// ambiguity of '(tokens)' vs. '( token )', due to sqlite requiring byte offsets
58fn flatten_stream(tokens: TokenStream, result: &mut Vec<(String, Span)>) {
59 for token_tree in tokens.into_iter() {
60 match token_tree {
61 TokenTree::Group(group) => {
62 // push open delimiter
63 result.push((open_delimiter(group.delimiter()), group.span()));
64 // recurse
65 flatten_stream(group.stream(), result);
66 // push close delimiter
67 result.push((close_delimiter(group.delimiter()), group.span()));
68 }
69 TokenTree::Ident(ident) => {
70 result.push((format!("{} ", ident), ident.span()));
71 }
72 leaf_tree => result.push((leaf_tree.to_string(), leaf_tree.span())),
73 }
74 }
75}
76
/// Returns the SQL text for the opening delimiter of a token group, padded
/// with a trailing space.
///
/// In `proc_macro`, `Delimiter::Brace` is `{ ... }` and `Delimiter::Bracket`
/// is `[ ... ]`; the text emitted here must match what was written in the
/// macro input. `Delimiter::None` produces no text.
fn open_delimiter(delimiter: Delimiter) -> String {
    match delimiter {
        Delimiter::Parenthesis => "( ",
        Delimiter::Brace => "{ ",
        Delimiter::Bracket => "[ ",
        Delimiter::None => "",
    }
    .to_string()
}

/// Returns the SQL text for the closing delimiter of a token group, padded
/// with a space on both sides.
///
/// Mirrors `open_delimiter`: `Delimiter::Brace` closes with `}` and
/// `Delimiter::Bracket` closes with `]`. `Delimiter::None` produces no text.
fn close_delimiter(delimiter: Delimiter) -> String {
    match delimiter {
        Delimiter::Parenthesis => " ) ",
        Delimiter::Brace => " } ",
        Delimiter::Bracket => " ] ",
        Delimiter::None => "",
    }
    .to_string()
}