1use proc_macro::{Delimiter, Span, TokenStream, TokenTree};
2use syn::Error;
3
/// Lazily created in-memory (`":memory:"`) SQLite connection that `sql!` uses
/// to syntax-check queries during macro expansion.
///
/// Not compiled on Linux or FreeBSD, where `sql!` skips the syntax check.
#[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
static SQLITE: std::sync::LazyLock<sqlez::thread_safe_connection::ThreadSafeConnection> =
    std::sync::LazyLock::new(|| {
        use sqlez::thread_safe_connection::{locking_queue, ThreadSafeConnection};

        let queue = Some(locking_queue());
        ThreadSafeConnection::new(":memory:", false, None, queue)
    });
14
15#[proc_macro]
16pub fn sql(tokens: TokenStream) -> TokenStream {
17 let (spans, sql) = make_sql(tokens);
18
19 #[cfg(not(any(target_os = "linux", target_os = "freebsd")))]
20 let error = SQLITE.sql_has_syntax_error(sql.trim());
21
22 #[cfg(any(target_os = "linux", target_os = "freebsd"))]
23 let error: Option<(String, usize)> = None;
24
25 let formatted_sql = sqlformat::format(&sql, &sqlformat::QueryParams::None, Default::default());
26
27 if let Some((error, error_offset)) = error {
28 create_error(spans, error_offset, error, &formatted_sql)
29 } else {
30 format!("r#\"{}\"#", &formatted_sql).parse().unwrap()
31 }
32}
33
34fn create_error(
35 spans: Vec<(usize, Span)>,
36 error_offset: usize,
37 error: String,
38 formatted_sql: &String,
39) -> TokenStream {
40 let error_span = spans
41 .into_iter()
42 .skip_while(|(offset, _)| offset <= &error_offset)
43 .map(|(_, span)| span)
44 .next()
45 .unwrap_or_else(Span::call_site);
46 let error_text = format!("Sql Error: {}\nFor Query: {}", error, formatted_sql);
47 TokenStream::from(Error::new(error_span.into(), error_text).into_compile_error())
48}
49
50fn make_sql(tokens: TokenStream) -> (Vec<(usize, Span)>, String) {
51 let mut sql_tokens = vec![];
52 flatten_stream(tokens, &mut sql_tokens);
53 // Lookup of spans by offset at the end of the token
54 let mut spans: Vec<(usize, Span)> = Vec::new();
55 let mut sql = String::new();
56 for (token_text, span) in sql_tokens {
57 sql.push_str(&token_text);
58 spans.push((sql.len(), span));
59 }
60 (spans, sql)
61}
62
63/// This method exists to normalize the representation of groups
64/// to always include spaces between tokens. This is why we don't use the usual .to_string().
65/// This allows our token search in token_at_offset to resolve
66/// ambiguity of '(tokens)' vs. '( token )', due to sqlite requiring byte offsets
67fn flatten_stream(tokens: TokenStream, result: &mut Vec<(String, Span)>) {
68 for token_tree in tokens.into_iter() {
69 match token_tree {
70 TokenTree::Group(group) => {
71 // push open delimiter
72 result.push((open_delimiter(group.delimiter()), group.span()));
73 // recurse
74 flatten_stream(group.stream(), result);
75 // push close delimiter
76 result.push((close_delimiter(group.delimiter()), group.span()));
77 }
78 TokenTree::Ident(ident) => {
79 result.push((format!("{} ", ident), ident.span()));
80 }
81 leaf_tree => result.push((leaf_tree.to_string(), leaf_tree.span())),
82 }
83 }
84}
85
/// Returns the SQL text emitted for the opening delimiter of a token group.
///
/// `Delimiter::Brace` is `{`, `Delimiter::Bracket` is `[`, and an invisible
/// `Delimiter::None` group contributes no text. The trailing space keeps the
/// delimiter separate from the following token.
fn open_delimiter(delimiter: Delimiter) -> String {
    String::from(match delimiter {
        Delimiter::Parenthesis => "( ",
        Delimiter::Brace => "{ ",
        Delimiter::Bracket => "[ ",
        Delimiter::None => "",
    })
}

/// Returns the SQL text emitted for the closing delimiter of a token group.
///
/// This mirrors `open_delimiter`: `Delimiter::Brace` is `}`,
/// `Delimiter::Bracket` is `]`, and `Delimiter::None` contributes no text.
/// Surrounding spaces keep the delimiter separate from neighboring tokens.
fn close_delimiter(delimiter: Delimiter) -> String {
    String::from(match delimiter {
        Delimiter::Parenthesis => " ) ",
        Delimiter::Brace => " } ",
        Delimiter::Bracket => " ] ",
        Delimiter::None => "",
    })
}