derive_refineable.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::TokenStream as TokenStream2;
  3use quote::{format_ident, quote};
  4use syn::{
  5    parse_macro_input, parse_quote, DeriveInput, Field, FieldsNamed, PredicateType, TraitBound,
  6    Type, TypeParamBound, WhereClause, WherePredicate,
  7};
  8
  9#[proc_macro_derive(Refineable, attributes(refineable))]
 10pub fn derive_refineable(input: TokenStream) -> TokenStream {
 11    let DeriveInput {
 12        ident,
 13        data,
 14        generics,
 15        attrs,
 16        ..
 17    } = parse_macro_input!(input);
 18
 19    let impl_debug_on_refinement = attrs
 20        .iter()
 21        .any(|attr| attr.path.is_ident("refineable") && attr.tokens.to_string().contains("debug"));
 22
 23    let refinement_ident = format_ident!("{}Refinement", ident);
 24    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
 25
 26    let fields = match data {
 27        syn::Data::Struct(syn::DataStruct {
 28            fields: syn::Fields::Named(FieldsNamed { named, .. }),
 29            ..
 30        }) => named.into_iter().collect::<Vec<Field>>(),
 31        _ => panic!("This derive macro only supports structs with named fields"),
 32    };
 33
 34    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
 35    let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
 36    let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
 37
 38    // Create trait bound that each wrapped type must implement Clone // & Default
 39    let type_param_bounds: Vec<_> = wrapped_types
 40        .iter()
 41        .map(|ty| {
 42            WherePredicate::Type(PredicateType {
 43                lifetimes: None,
 44                bounded_ty: ty.clone(),
 45                colon_token: Default::default(),
 46                bounds: {
 47                    let mut punctuated = syn::punctuated::Punctuated::new();
 48                    punctuated.push_value(TypeParamBound::Trait(TraitBound {
 49                        paren_token: None,
 50                        modifier: syn::TraitBoundModifier::None,
 51                        lifetimes: None,
 52                        path: parse_quote!(Clone),
 53                    }));
 54
 55                    // punctuated.push_punct(syn::token::Add::default());
 56                    // punctuated.push_value(TypeParamBound::Trait(TraitBound {
 57                    //     paren_token: None,
 58                    //     modifier: syn::TraitBoundModifier::None,
 59                    //     lifetimes: None,
 60                    //     path: parse_quote!(Default),
 61                    // }));
 62                    punctuated
 63                },
 64            })
 65        })
 66        .collect();
 67
 68    // Append to where_clause or create a new one if it doesn't exist
 69    let where_clause = match where_clause.cloned() {
 70        Some(mut where_clause) => {
 71            where_clause
 72                .predicates
 73                .extend(type_param_bounds.into_iter());
 74            where_clause.clone()
 75        }
 76        None => WhereClause {
 77            where_token: Default::default(),
 78            predicates: type_param_bounds.into_iter().collect(),
 79        },
 80    };
 81
 82    let field_assignments: Vec<TokenStream2> = fields
 83        .iter()
 84        .map(|field| {
 85            let name = &field.ident;
 86            let is_refineable = is_refineable_field(field);
 87            let is_optional = is_optional_field(field);
 88
 89            if is_refineable {
 90                quote! {
 91                    self.#name.refine(&refinement.#name);
 92                }
 93            } else if is_optional {
 94                quote! {
 95                    if let Some(ref value) = &refinement.#name {
 96                        self.#name = Some(value.clone());
 97                    }
 98                }
 99            } else {
100                quote! {
101                    if let Some(ref value) = &refinement.#name {
102                        self.#name = value.clone();
103                    }
104                }
105            }
106        })
107        .collect();
108
109    let refinement_field_assignments: Vec<TokenStream2> = fields
110        .iter()
111        .map(|field| {
112            let name = &field.ident;
113            let is_refineable = is_refineable_field(field);
114
115            if is_refineable {
116                quote! {
117                    self.#name.refine(&refinement.#name);
118                }
119            } else {
120                quote! {
121                    if let Some(ref value) = &refinement.#name {
122                        self.#name = Some(value.clone());
123                    }
124                }
125            }
126        })
127        .collect();
128
129    let debug_impl = if impl_debug_on_refinement {
130        let refinement_field_debugs: Vec<TokenStream2> = fields
131            .iter()
132            .map(|field| {
133                let name = &field.ident;
134                quote! {
135                    if self.#name.is_some() {
136                        debug_struct.field(stringify!(#name), &self.#name);
137                    } else {
138                        all_some = false;
139                    }
140                }
141            })
142            .collect();
143
144        quote! {
145            impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
146                #where_clause
147            {
148                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149                    let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
150                    let mut all_some = true;
151                    #( #refinement_field_debugs )*
152                    if all_some {
153                        debug_struct.finish()
154                    } else {
155                        debug_struct.finish_non_exhaustive()
156                    }
157                }
158            }
159        }
160    } else {
161        quote! {}
162    };
163
164    let gen = quote! {
165        #[derive(Clone)]
166        pub struct #refinement_ident #impl_generics {
167            #( #field_visibilities #field_names: #wrapped_types ),*
168        }
169
170        impl #impl_generics Refineable for #ident #ty_generics
171            #where_clause
172        {
173            type Refinement = #refinement_ident #ty_generics;
174
175            fn refine(&mut self, refinement: &Self::Refinement) {
176                #( #field_assignments )*
177            }
178        }
179
180        impl #impl_generics Refineable for #refinement_ident #ty_generics
181            #where_clause
182        {
183            type Refinement = #refinement_ident #ty_generics;
184
185            fn refine(&mut self, refinement: &Self::Refinement) {
186                #( #refinement_field_assignments )*
187            }
188        }
189
190        impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
191            #where_clause
192        {
193            fn default() -> Self {
194                #refinement_ident {
195                    #( #field_names: Default::default() ),*
196                }
197            }
198        }
199
200        impl #impl_generics #refinement_ident #ty_generics
201            #where_clause
202        {
203            pub fn is_some(&self) -> bool {
204                #(
205                    if self.#field_names.is_some() {
206                        return true;
207                    }
208                )*
209                false
210            }
211        }
212
213        #debug_impl
214    };
215    gen.into()
216}
217
218fn is_refineable_field(f: &Field) -> bool {
219    f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
220}
221
222fn is_optional_field(f: &Field) -> bool {
223    if let Type::Path(typepath) = &f.ty {
224        if typepath.qself.is_none() {
225            let segments = &typepath.path.segments;
226            if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
227                return true;
228            }
229        }
230    }
231    false
232}
233
234fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
235    if is_refineable_field(field) {
236        let struct_name = if let Type::Path(tp) = ty {
237            tp.path.segments.last().unwrap().ident.clone()
238        } else {
239            panic!("Expected struct type for a refineable field");
240        };
241        let refinement_struct_name = format_ident!("{}Refinement", struct_name);
242        let generics = if let Type::Path(tp) = ty {
243            &tp.path.segments.last().unwrap().arguments
244        } else {
245            &syn::PathArguments::None
246        };
247        parse_quote!(#refinement_struct_name #generics)
248    } else if is_optional_field(field) {
249        ty.clone()
250    } else {
251        parse_quote!(Option<#ty>)
252    }
253}