1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{
5 parse_macro_input, parse_quote, DeriveInput, Field, FieldsNamed, PredicateType, TraitBound,
6 Type, TypeParamBound, WhereClause, WherePredicate,
7};
8
9#[proc_macro_derive(Refineable, attributes(refineable))]
10pub fn derive_refineable(input: TokenStream) -> TokenStream {
11 let DeriveInput {
12 ident,
13 data,
14 generics,
15 ..
16 } = parse_macro_input!(input);
17
18 let refinement_ident = format_ident!("{}Refinement", ident);
19 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
20
21 let fields = match data {
22 syn::Data::Struct(syn::DataStruct {
23 fields: syn::Fields::Named(FieldsNamed { named, .. }),
24 ..
25 }) => named.into_iter().collect::<Vec<Field>>(),
26 _ => panic!("This derive macro only supports structs with named fields"),
27 };
28
29 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
30 let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
31 let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
32
33 // Create trait bound that each wrapped type must implement Clone & Default
34 let type_param_bounds: Vec<_> = wrapped_types
35 .iter()
36 .map(|ty| {
37 WherePredicate::Type(PredicateType {
38 lifetimes: None,
39 bounded_ty: ty.clone(),
40 colon_token: Default::default(),
41 bounds: {
42 let mut punctuated = syn::punctuated::Punctuated::new();
43 punctuated.push_value(TypeParamBound::Trait(TraitBound {
44 paren_token: None,
45 modifier: syn::TraitBoundModifier::None,
46 lifetimes: None,
47 path: parse_quote!(Clone),
48 }));
49 punctuated.push_punct(syn::token::Add::default());
50 punctuated.push_value(TypeParamBound::Trait(TraitBound {
51 paren_token: None,
52 modifier: syn::TraitBoundModifier::None,
53 lifetimes: None,
54 path: parse_quote!(Default),
55 }));
56 punctuated
57 },
58 })
59 })
60 .collect();
61
62 // Append to where_clause or create a new one if it doesn't exist
63 let where_clause = match where_clause.cloned() {
64 Some(mut where_clause) => {
65 where_clause
66 .predicates
67 .extend(type_param_bounds.into_iter());
68 where_clause.clone()
69 }
70 None => WhereClause {
71 where_token: Default::default(),
72 predicates: type_param_bounds.into_iter().collect(),
73 },
74 };
75
76 let field_assignments: Vec<TokenStream2> = fields
77 .iter()
78 .map(|field| {
79 let name = &field.ident;
80 let is_refineable = is_refineable_field(field);
81 let is_optional = is_optional_field(field);
82
83 if is_refineable {
84 quote! {
85 self.#name.refine(&refinement.#name);
86 }
87 } else if is_optional {
88 quote! {
89 if let Some(ref value) = &refinement.#name {
90 self.#name = Some(value.clone());
91 }
92 }
93 } else {
94 quote! {
95 if let Some(ref value) = &refinement.#name {
96 self.#name = value.clone();
97 }
98 }
99 }
100 })
101 .collect();
102
103 let refinement_field_assignments: Vec<TokenStream2> = fields
104 .iter()
105 .map(|field| {
106 let name = &field.ident;
107 let is_refineable = is_refineable_field(field);
108
109 if is_refineable {
110 quote! {
111 self.#name.refine(&refinement.#name);
112 }
113 } else {
114 quote! {
115 if let Some(ref value) = &refinement.#name {
116 self.#name = Some(value.clone());
117 }
118 }
119 }
120 })
121 .collect();
122
123 let gen = quote! {
124 #[derive(Default, Clone)]
125 pub struct #refinement_ident #impl_generics {
126 #( #field_visibilities #field_names: #wrapped_types ),*
127 }
128
129 impl #impl_generics Refineable for #ident #ty_generics
130 #where_clause
131 {
132 type Refinement = #refinement_ident #ty_generics;
133
134 fn refine(&mut self, refinement: &Self::Refinement) {
135 #( #field_assignments )*
136 }
137 }
138
139 impl #impl_generics Refineable for #refinement_ident #ty_generics
140 #where_clause
141 {
142 type Refinement = #refinement_ident #ty_generics;
143
144 fn refine(&mut self, refinement: &Self::Refinement) {
145 #( #refinement_field_assignments )*
146 }
147 }
148 };
149
150 gen.into()
151}
152
153fn is_refineable_field(f: &Field) -> bool {
154 f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
155}
156
157fn is_optional_field(f: &Field) -> bool {
158 if let Type::Path(typepath) = &f.ty {
159 if typepath.qself.is_none() {
160 let segments = &typepath.path.segments;
161 if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
162 return true;
163 }
164 }
165 }
166 false
167}
168
169fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
170 if is_refineable_field(field) {
171 let struct_name = if let Type::Path(tp) = ty {
172 tp.path.segments.last().unwrap().ident.clone()
173 } else {
174 panic!("Expected struct type for a refineable field");
175 };
176 let refinement_struct_name = format_ident!("{}Refinement", struct_name);
177 let generics = if let Type::Path(tp) = ty {
178 &tp.path.segments.last().unwrap().arguments
179 } else {
180 &syn::PathArguments::None
181 };
182 parse_quote!(#refinement_struct_name #generics)
183 } else if is_optional_field(field) {
184 ty.clone()
185 } else {
186 parse_quote!(Option<#ty>)
187 }
188}