1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{
5 parse_macro_input, parse_quote, DeriveInput, Field, FieldsNamed, PredicateType, TraitBound,
6 Type, TypeParamBound, WhereClause, WherePredicate,
7};
8
9#[proc_macro_derive(Refineable, attributes(refineable))]
10pub fn derive_refineable(input: TokenStream) -> TokenStream {
11 let DeriveInput {
12 ident,
13 data,
14 generics,
15 attrs,
16 ..
17 } = parse_macro_input!(input);
18
19 let impl_debug_on_refinement = attrs
20 .iter()
21 .any(|attr| attr.path.is_ident("refineable") && attr.tokens.to_string().contains("debug"));
22
23 let refinement_ident = format_ident!("{}Refinement", ident);
24 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
25
26 let fields = match data {
27 syn::Data::Struct(syn::DataStruct {
28 fields: syn::Fields::Named(FieldsNamed { named, .. }),
29 ..
30 }) => named.into_iter().collect::<Vec<Field>>(),
31 _ => panic!("This derive macro only supports structs with named fields"),
32 };
33
34 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
35 let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
36 let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
37
38 // Create trait bound that each wrapped type must implement Clone // & Default
39 let type_param_bounds: Vec<_> = wrapped_types
40 .iter()
41 .map(|ty| {
42 WherePredicate::Type(PredicateType {
43 lifetimes: None,
44 bounded_ty: ty.clone(),
45 colon_token: Default::default(),
46 bounds: {
47 let mut punctuated = syn::punctuated::Punctuated::new();
48 punctuated.push_value(TypeParamBound::Trait(TraitBound {
49 paren_token: None,
50 modifier: syn::TraitBoundModifier::None,
51 lifetimes: None,
52 path: parse_quote!(Clone),
53 }));
54
55 // punctuated.push_punct(syn::token::Add::default());
56 // punctuated.push_value(TypeParamBound::Trait(TraitBound {
57 // paren_token: None,
58 // modifier: syn::TraitBoundModifier::None,
59 // lifetimes: None,
60 // path: parse_quote!(Default),
61 // }));
62 punctuated
63 },
64 })
65 })
66 .collect();
67
68 // Append to where_clause or create a new one if it doesn't exist
69 let where_clause = match where_clause.cloned() {
70 Some(mut where_clause) => {
71 where_clause
72 .predicates
73 .extend(type_param_bounds.into_iter());
74 where_clause.clone()
75 }
76 None => WhereClause {
77 where_token: Default::default(),
78 predicates: type_param_bounds.into_iter().collect(),
79 },
80 };
81
82 // refinable_refine_assignments
83 // refinable_refined_assignments
84 // refinement_refine_assignments
85
86 let refineable_refine_assignments: Vec<TokenStream2> = fields
87 .iter()
88 .map(|field| {
89 let name = &field.ident;
90 let is_refineable = is_refineable_field(field);
91 let is_optional = is_optional_field(field);
92
93 if is_refineable {
94 quote! {
95 self.#name.refine(&refinement.#name);
96 }
97 } else if is_optional {
98 quote! {
99 if let Some(ref value) = &refinement.#name {
100 self.#name = Some(value.clone());
101 }
102 }
103 } else {
104 quote! {
105 if let Some(ref value) = &refinement.#name {
106 self.#name = value.clone();
107 }
108 }
109 }
110 })
111 .collect();
112
113 let refineable_refined_assignments: Vec<TokenStream2> = fields
114 .iter()
115 .map(|field| {
116 let name = &field.ident;
117 let is_refineable = is_refineable_field(field);
118 let is_optional = is_optional_field(field);
119
120 if is_refineable {
121 quote! {
122 self.#name = self.#name.refined(refinement.#name);
123 }
124 } else if is_optional {
125 quote! {
126 if let Some(value) = refinement.#name {
127 self.#name = Some(value);
128 }
129 }
130 } else {
131 quote! {
132 if let Some(value) = refinement.#name {
133 self.#name = value;
134 }
135 }
136 }
137 })
138 .collect();
139
140 let refinement_refine_assigments: Vec<TokenStream2> = fields
141 .iter()
142 .map(|field| {
143 let name = &field.ident;
144 let is_refineable = is_refineable_field(field);
145
146 if is_refineable {
147 quote! {
148 self.#name.refine(&refinement.#name);
149 }
150 } else {
151 quote! {
152 if let Some(ref value) = &refinement.#name {
153 self.#name = Some(value.clone());
154 }
155 }
156 }
157 })
158 .collect();
159
160 let refinement_refined_assignments: Vec<TokenStream2> = fields
161 .iter()
162 .map(|field| {
163 let name = &field.ident;
164 let is_refineable = is_refineable_field(field);
165
166 if is_refineable {
167 quote! {
168 self.#name = self.#name.refined(refinement.#name);
169 }
170 } else {
171 quote! {
172 if refinement.#name.is_some() {
173 self.#name = refinement.#name;
174 }
175 }
176 }
177 })
178 .collect();
179
180 let debug_impl = if impl_debug_on_refinement {
181 let refinement_field_debugs: Vec<TokenStream2> = fields
182 .iter()
183 .map(|field| {
184 let name = &field.ident;
185 quote! {
186 if self.#name.is_some() {
187 debug_struct.field(stringify!(#name), &self.#name);
188 } else {
189 all_some = false;
190 }
191 }
192 })
193 .collect();
194
195 quote! {
196 impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
197 #where_clause
198 {
199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
201 let mut all_some = true;
202 #( #refinement_field_debugs )*
203 if all_some {
204 debug_struct.finish()
205 } else {
206 debug_struct.finish_non_exhaustive()
207 }
208 }
209 }
210 }
211 } else {
212 quote! {}
213 };
214
215 let gen = quote! {
216 #[derive(Clone)]
217 pub struct #refinement_ident #impl_generics {
218 #( #field_visibilities #field_names: #wrapped_types ),*
219 }
220
221 impl #impl_generics Refineable for #ident #ty_generics
222 #where_clause
223 {
224 type Refinement = #refinement_ident #ty_generics;
225
226 fn refine(&mut self, refinement: &Self::Refinement) {
227 #( #refineable_refine_assignments )*
228 }
229
230 fn refined(mut self, refinement: Self::Refinement) -> Self {
231 #( #refineable_refined_assignments )*
232 self
233 }
234 }
235
236 impl #impl_generics Refineable for #refinement_ident #ty_generics
237 #where_clause
238 {
239 type Refinement = #refinement_ident #ty_generics;
240
241 fn refine(&mut self, refinement: &Self::Refinement) {
242 #( #refinement_refine_assigments )*
243 }
244
245 fn refined(mut self, refinement: Self::Refinement) -> Self {
246 #( #refinement_refined_assignments )*
247 self
248 }
249 }
250
251 impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
252 #where_clause
253 {
254 fn default() -> Self {
255 #refinement_ident {
256 #( #field_names: Default::default() ),*
257 }
258 }
259 }
260
261 impl #impl_generics #refinement_ident #ty_generics
262 #where_clause
263 {
264 pub fn is_some(&self) -> bool {
265 #(
266 if self.#field_names.is_some() {
267 return true;
268 }
269 )*
270 false
271 }
272 }
273
274 #debug_impl
275 };
276
277 gen.into()
278}
279
280fn is_refineable_field(f: &Field) -> bool {
281 f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
282}
283
284fn is_optional_field(f: &Field) -> bool {
285 if let Type::Path(typepath) = &f.ty {
286 if typepath.qself.is_none() {
287 let segments = &typepath.path.segments;
288 if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
289 return true;
290 }
291 }
292 }
293 false
294}
295
296fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
297 if is_refineable_field(field) {
298 let struct_name = if let Type::Path(tp) = ty {
299 tp.path.segments.last().unwrap().ident.clone()
300 } else {
301 panic!("Expected struct type for a refineable field");
302 };
303 let refinement_struct_name = format_ident!("{}Refinement", struct_name);
304 let generics = if let Type::Path(tp) = ty {
305 &tp.path.segments.last().unwrap().arguments
306 } else {
307 &syn::PathArguments::None
308 };
309 parse_quote!(#refinement_struct_name #generics)
310 } else if is_optional_field(field) {
311 ty.clone()
312 } else {
313 parse_quote!(Option<#ty>)
314 }
315}