derive_refineable.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::TokenStream as TokenStream2;
  3use quote::{format_ident, quote};
  4use syn::{
  5    parse_macro_input, parse_quote, DeriveInput, Field, FieldsNamed, PredicateType, TraitBound,
  6    Type, TypeParamBound, WhereClause, WherePredicate,
  7};
  8
  9#[proc_macro_derive(Refineable, attributes(refineable))]
 10pub fn derive_refineable(input: TokenStream) -> TokenStream {
 11    let DeriveInput {
 12        ident,
 13        data,
 14        generics,
 15        ..
 16    } = parse_macro_input!(input);
 17
 18    let refinement_ident = format_ident!("{}Refinement", ident);
 19    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
 20
 21    let fields = match data {
 22        syn::Data::Struct(syn::DataStruct {
 23            fields: syn::Fields::Named(FieldsNamed { named, .. }),
 24            ..
 25        }) => named.into_iter().collect::<Vec<Field>>(),
 26        _ => panic!("This derive macro only supports structs with named fields"),
 27    };
 28
 29    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
 30    let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
 31    let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
 32
 33    // Create trait bound that each wrapped type must implement Clone & Default
 34    let type_param_bounds: Vec<_> = wrapped_types
 35        .iter()
 36        .map(|ty| {
 37            WherePredicate::Type(PredicateType {
 38                lifetimes: None,
 39                bounded_ty: ty.clone(),
 40                colon_token: Default::default(),
 41                bounds: {
 42                    let mut punctuated = syn::punctuated::Punctuated::new();
 43                    punctuated.push_value(TypeParamBound::Trait(TraitBound {
 44                        paren_token: None,
 45                        modifier: syn::TraitBoundModifier::None,
 46                        lifetimes: None,
 47                        path: parse_quote!(Clone),
 48                    }));
 49                    punctuated.push_punct(syn::token::Add::default());
 50                    punctuated.push_value(TypeParamBound::Trait(TraitBound {
 51                        paren_token: None,
 52                        modifier: syn::TraitBoundModifier::None,
 53                        lifetimes: None,
 54                        path: parse_quote!(Default),
 55                    }));
 56                    punctuated
 57                },
 58            })
 59        })
 60        .collect();
 61
 62    // Append to where_clause or create a new one if it doesn't exist
 63    let where_clause = match where_clause.cloned() {
 64        Some(mut where_clause) => {
 65            where_clause
 66                .predicates
 67                .extend(type_param_bounds.into_iter());
 68            where_clause.clone()
 69        }
 70        None => WhereClause {
 71            where_token: Default::default(),
 72            predicates: type_param_bounds.into_iter().collect(),
 73        },
 74    };
 75
 76    let field_assignments: Vec<TokenStream2> = fields
 77        .iter()
 78        .map(|field| {
 79            let name = &field.ident;
 80            let is_refineable = is_refineable_field(field);
 81            let is_optional = is_optional_field(field);
 82
 83            if is_refineable {
 84                quote! {
 85                    self.#name.refine(&refinement.#name);
 86                }
 87            } else if is_optional {
 88                quote! {
 89                    if let Some(ref value) = &refinement.#name {
 90                        self.#name = Some(value.clone());
 91                    }
 92                }
 93            } else {
 94                quote! {
 95                    if let Some(ref value) = &refinement.#name {
 96                        self.#name = value.clone();
 97                    }
 98                }
 99            }
100        })
101        .collect();
102
103    let refinement_field_assignments: Vec<TokenStream2> = fields
104        .iter()
105        .map(|field| {
106            let name = &field.ident;
107            let is_refineable = is_refineable_field(field);
108
109            if is_refineable {
110                quote! {
111                    self.#name.refine(&refinement.#name);
112                }
113            } else {
114                quote! {
115                    if let Some(ref value) = &refinement.#name {
116                        self.#name = Some(value.clone());
117                    }
118                }
119            }
120        })
121        .collect();
122
123    let gen = quote! {
124        #[derive(Default, Clone)]
125        pub struct #refinement_ident #impl_generics {
126            #( #field_visibilities #field_names: #wrapped_types ),*
127        }
128
129        impl #impl_generics Refineable for #ident #ty_generics
130            #where_clause
131        {
132            type Refinement = #refinement_ident #ty_generics;
133
134            fn refine(&mut self, refinement: &Self::Refinement) {
135                #( #field_assignments )*
136            }
137        }
138
139        impl #impl_generics Refineable for #refinement_ident #ty_generics
140            #where_clause
141        {
142            type Refinement = #refinement_ident #ty_generics;
143
144            fn refine(&mut self, refinement: &Self::Refinement) {
145                #( #refinement_field_assignments )*
146            }
147        }
148    };
149
150    gen.into()
151}
152
153fn is_refineable_field(f: &Field) -> bool {
154    f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
155}
156
157fn is_optional_field(f: &Field) -> bool {
158    if let Type::Path(typepath) = &f.ty {
159        if typepath.qself.is_none() {
160            let segments = &typepath.path.segments;
161            if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
162                return true;
163            }
164        }
165    }
166    false
167}
168
169fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
170    if is_refineable_field(field) {
171        let struct_name = if let Type::Path(tp) = ty {
172            tp.path.segments.last().unwrap().ident.clone()
173        } else {
174            panic!("Expected struct type for a refineable field");
175        };
176        let refinement_struct_name = format_ident!("{}Refinement", struct_name);
177        let generics = if let Type::Path(tp) = ty {
178            &tp.path.segments.last().unwrap().arguments
179        } else {
180            &syn::PathArguments::None
181        };
182        parse_quote!(#refinement_struct_name #generics)
183    } else if is_optional_field(field) {
184        ty.clone()
185    } else {
186        parse_quote!(Option<#ty>)
187    }
188}