util_macros.rs

  1use convert_case::{Case, Casing};
  2use proc_macro::TokenStream;
  3use proc_macro2::TokenStream as TokenStream2;
  4use quote::{format_ident, quote};
  5use syn::{
  6    Data, DeriveInput, Expr, ExprArray, ExprLit, Fields, Lit, LitStr, MetaNameValue, Token,
  7    parse_macro_input, punctuated::Punctuated,
  8};
  9
 10/// A macro used in tests for cross-platform path string literals in tests. On Windows it replaces
 11/// `/` with `\\` and adds `C:` to the beginning of absolute paths. On other platforms, the path is
 12/// returned unmodified.
 13///
 14/// # Example
 15/// ```rust
 16/// use util_macros::path;
 17///
 18/// let path = path!("/Users/user/file.txt");
 19/// #[cfg(target_os = "windows")]
 20/// assert_eq!(path, "C:\\Users\\user\\file.txt");
 21/// #[cfg(not(target_os = "windows"))]
 22/// assert_eq!(path, "/Users/user/file.txt");
 23/// ```
 24#[proc_macro]
 25pub fn path(input: TokenStream) -> TokenStream {
 26    let path = parse_macro_input!(input as LitStr);
 27    let mut path = path.value();
 28
 29    #[cfg(target_os = "windows")]
 30    {
 31        path = path.replace("/", "\\");
 32        if path.starts_with("\\") {
 33            path = format!("C:{}", path);
 34        }
 35    }
 36
 37    TokenStream::from(quote! {
 38        #path
 39    })
 40}
 41
 42/// This macro replaces the path prefix `file:///` with `file:///C:/` for Windows.
 43/// But if the target OS is not Windows, the URI is returned as is.
 44///
 45/// # Example
 46/// ```rust
 47/// use util_macros::uri;
 48///
 49/// let uri = uri!("file:///path/to/file");
 50/// #[cfg(target_os = "windows")]
 51/// assert_eq!(uri, "file:///C:/path/to/file");
 52/// #[cfg(not(target_os = "windows"))]
 53/// assert_eq!(uri, "file:///path/to/file");
 54/// ```
 55#[proc_macro]
 56pub fn uri(input: TokenStream) -> TokenStream {
 57    let uri = parse_macro_input!(input as LitStr);
 58    let uri = uri.value();
 59
 60    #[cfg(target_os = "windows")]
 61    let uri = uri.replace("file:///", "file:///C:/");
 62
 63    TokenStream::from(quote! {
 64        #uri
 65    })
 66}
 67
 68/// This macro replaces the line endings `\n` with `\r\n` for Windows.
 69/// But if the target OS is not Windows, the line endings are returned as is.
 70///
 71/// # Example
 72/// ```rust
 73/// use util_macros::line_endings;
 74///
 75/// let text = line_endings!("Hello\nWorld");
 76/// #[cfg(target_os = "windows")]
 77/// assert_eq!(text, "Hello\r\nWorld");
 78/// #[cfg(not(target_os = "windows"))]
 79/// assert_eq!(text, "Hello\nWorld");
 80/// ```
 81#[proc_macro]
 82pub fn line_endings(input: TokenStream) -> TokenStream {
 83    let text = parse_macro_input!(input as LitStr);
 84    let text = text.value();
 85
 86    #[cfg(target_os = "windows")]
 87    let text = text.replace("\n", "\r\n");
 88
 89    TokenStream::from(quote! {
 90        #text
 91    })
 92}
 93
 94/// Derive macro that generates an enum and implements `FieldAccessByEnum`. Only works for structs
 95/// with named fields where every field has the same type.
 96///
 97/// # Example
 98///
 99/// ```rust
100/// #[derive(FieldAccessByEnum)]
101/// #[field_access_by_enum(
102///     enum_name = "ColorField",
103///     enum_attrs = [
104///         derive(Debug, Clone, Copy, EnumIter, AsRefStr),
105///         strum(serialize_all = "snake_case")
106///     ],
107/// )]
108/// struct Theme {
109///     background: Hsla,
110///     foreground: Hsla,
111///     border_color: Hsla,
112/// }
113/// ```
114///
115/// This generates:
116/// ```rust
117/// #[derive(Debug, Clone, Copy, EnumIter, AsRefStr)]
118/// #[strum(serialize_all = "snake_case")]
119/// enum ColorField {
120///     Background,
121///     Foreground,
122///     BorderColor,
123/// }
124///
125/// impl FieldAccessByEnum for Theme {
126///     type Field = ColorField;
127///     type FieldValue = Hsla;
128///     // ... get and set methods
129/// }
130/// ```
131#[proc_macro_derive(FieldAccessByEnum, attributes(field_access_by_enum))]
132pub fn derive_field_access_by_enum(input: TokenStream) -> TokenStream {
133    let input = parse_macro_input!(input as DeriveInput);
134
135    let struct_name = &input.ident;
136
137    let mut enum_name = None;
138    let mut enum_attrs: Vec<TokenStream2> = Vec::new();
139
140    for attr in &input.attrs {
141        if attr.path().is_ident("field_access_by_enum") {
142            let name_values: Punctuated<MetaNameValue, Token![,]> =
143                attr.parse_args_with(Punctuated::parse_terminated).unwrap();
144            for name_value in name_values {
145                if name_value.path.is_ident("enum_name") {
146                    let value = name_value.value;
147                    match value {
148                        Expr::Lit(ExprLit {
149                            lit: Lit::Str(name),
150                            ..
151                        }) => enum_name = Some(name.value()),
152                        _ => panic!("Expected string literal in enum_name attribute"),
153                    }
154                } else if name_value.path.is_ident("enum_attrs") {
155                    let value = name_value.value;
156                    match value {
157                        Expr::Array(ExprArray { elems, .. }) => {
158                            for elem in elems {
159                                enum_attrs.push(quote!(#[#elem]));
160                            }
161                        }
162                        _ => panic!("Expected array literal in enum_attr attribute"),
163                    }
164                }
165            }
166        }
167    }
168    let Some(enum_name) = enum_name else {
169        panic!("#[field_access_by_enum(enum_name = \"...\")] attribute is required");
170    };
171    let enum_ident = format_ident!("{}", enum_name);
172
173    let fields = match input.data {
174        Data::Struct(data_struct) => match data_struct.fields {
175            Fields::Named(fields) => fields.named,
176            _ => panic!("FieldAccessByEnum can only be derived for structs with named fields"),
177        },
178        _ => panic!("FieldAccessByEnum can only be derived for structs"),
179    };
180
181    if fields.is_empty() {
182        panic!("FieldAccessByEnum cannot be derived for structs with no fields");
183    }
184
185    let mut enum_variants = Vec::new();
186    let mut get_match_arms = Vec::new();
187    let mut set_match_arms = Vec::new();
188    let mut field_types = Vec::new();
189
190    for field in fields.iter() {
191        let field_name = field.ident.as_ref().unwrap();
192        let variant_name = field_name.to_string().to_case(Case::Pascal);
193        let variant_ident = format_ident!("{}", variant_name);
194        let field_type = &field.ty;
195
196        enum_variants.push(variant_ident.clone());
197        field_types.push(field_type);
198
199        get_match_arms.push(quote! {
200            #enum_ident::#variant_ident => &self.#field_name,
201        });
202
203        set_match_arms.push(quote! {
204            #enum_ident::#variant_ident => self.#field_name = value,
205        });
206    }
207
208    let first_type = &field_types[0];
209    let all_same_type = field_types
210        .iter()
211        .all(|ty| quote!(#ty).to_string() == quote!(#first_type).to_string());
212    if !all_same_type {
213        panic!("Fields have different types.");
214    }
215    let field_value_type = quote! { #first_type };
216
217    let expanded = quote! {
218        #(#enum_attrs)*
219        pub enum #enum_ident {
220            #(#enum_variants),*
221        }
222
223        impl util::FieldAccessByEnum for #struct_name {
224            type Field = #enum_ident;
225            type FieldValue = #field_value_type;
226
227            fn get_field_by_enum(&self, field: Self::Field) -> &Self::FieldValue {
228                match field {
229                    #(#get_match_arms)*
230                }
231            }
232
233            fn set_field_by_enum(&mut self, field: Self::Field, value: Self::FieldValue) {
234                match field {
235                    #(#set_match_arms)*
236                }
237            }
238        }
239    };
240
241    TokenStream::from(expanded)
242}