1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{
5 parse_macro_input, parse_quote, DeriveInput, Field, FieldsNamed, PredicateType, TraitBound,
6 Type, TypeParamBound, WhereClause, WherePredicate,
7};
8
9#[proc_macro_derive(Refineable, attributes(refineable))]
10pub fn derive_refineable(input: TokenStream) -> TokenStream {
11 let DeriveInput {
12 ident,
13 data,
14 generics,
15 attrs,
16 ..
17 } = parse_macro_input!(input);
18
19 let refineable_attr = attrs.iter().find(|attr| attr.path.is_ident("refineable"));
20
21 let mut impl_debug_on_refinement = false;
22 let mut refinement_traits_to_derive = vec![];
23
24 if let Some(refineable_attr) = refineable_attr {
25 if let Ok(syn::Meta::List(meta_list)) = refineable_attr.parse_meta() {
26 for nested in meta_list.nested {
27 let syn::NestedMeta::Meta(syn::Meta::Path(path)) = nested else {
28 continue;
29 };
30
31 if path.is_ident("Debug") {
32 impl_debug_on_refinement = true;
33 } else {
34 refinement_traits_to_derive.push(path);
35 }
36 }
37 }
38 }
39
40 let refinement_ident = format_ident!("{}Refinement", ident);
41 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
42
43 let fields = match data {
44 syn::Data::Struct(syn::DataStruct {
45 fields: syn::Fields::Named(FieldsNamed { named, .. }),
46 ..
47 }) => named.into_iter().collect::<Vec<Field>>(),
48 _ => panic!("This derive macro only supports structs with named fields"),
49 };
50
51 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
52 let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
53 let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
54
55 // Create trait bound that each wrapped type must implement Clone // & Default
56 let type_param_bounds: Vec<_> = wrapped_types
57 .iter()
58 .map(|ty| {
59 WherePredicate::Type(PredicateType {
60 lifetimes: None,
61 bounded_ty: ty.clone(),
62 colon_token: Default::default(),
63 bounds: {
64 let mut punctuated = syn::punctuated::Punctuated::new();
65 punctuated.push_value(TypeParamBound::Trait(TraitBound {
66 paren_token: None,
67 modifier: syn::TraitBoundModifier::None,
68 lifetimes: None,
69 path: parse_quote!(Clone),
70 }));
71
72 // punctuated.push_punct(syn::token::Add::default());
73 // punctuated.push_value(TypeParamBound::Trait(TraitBound {
74 // paren_token: None,
75 // modifier: syn::TraitBoundModifier::None,
76 // lifetimes: None,
77 // path: parse_quote!(Default),
78 // }));
79 punctuated
80 },
81 })
82 })
83 .collect();
84
85 // Append to where_clause or create a new one if it doesn't exist
86 let where_clause = match where_clause.cloned() {
87 Some(mut where_clause) => {
88 where_clause.predicates.extend(type_param_bounds);
89 where_clause.clone()
90 }
91 None => WhereClause {
92 where_token: Default::default(),
93 predicates: type_param_bounds.into_iter().collect(),
94 },
95 };
96
97 // refinable_refine_assignments
98 // refinable_refined_assignments
99 // refinement_refine_assignments
100
101 let refineable_refine_assignments: Vec<TokenStream2> = fields
102 .iter()
103 .map(|field| {
104 let name = &field.ident;
105 let is_refineable = is_refineable_field(field);
106 let is_optional = is_optional_field(field);
107
108 if is_refineable {
109 quote! {
110 self.#name.refine(&refinement.#name);
111 }
112 } else if is_optional {
113 quote! {
114 if let Some(ref value) = &refinement.#name {
115 self.#name = Some(value.clone());
116 }
117 }
118 } else {
119 quote! {
120 if let Some(ref value) = &refinement.#name {
121 self.#name = value.clone();
122 }
123 }
124 }
125 })
126 .collect();
127
128 let refineable_refined_assignments: Vec<TokenStream2> = fields
129 .iter()
130 .map(|field| {
131 let name = &field.ident;
132 let is_refineable = is_refineable_field(field);
133 let is_optional = is_optional_field(field);
134
135 if is_refineable {
136 quote! {
137 self.#name = self.#name.refined(refinement.#name);
138 }
139 } else if is_optional {
140 quote! {
141 if let Some(value) = refinement.#name {
142 self.#name = Some(value);
143 }
144 }
145 } else {
146 quote! {
147 if let Some(value) = refinement.#name {
148 self.#name = value;
149 }
150 }
151 }
152 })
153 .collect();
154
155 let refinement_refine_assigments: Vec<TokenStream2> = fields
156 .iter()
157 .map(|field| {
158 let name = &field.ident;
159 let is_refineable = is_refineable_field(field);
160
161 if is_refineable {
162 quote! {
163 self.#name.refine(&refinement.#name);
164 }
165 } else {
166 quote! {
167 if let Some(ref value) = &refinement.#name {
168 self.#name = Some(value.clone());
169 }
170 }
171 }
172 })
173 .collect();
174
175 let refinement_refined_assigments: Vec<TokenStream2> = fields
176 .iter()
177 .map(|field| {
178 let name = &field.ident;
179 let is_refineable = is_refineable_field(field);
180
181 if is_refineable {
182 quote! {
183 self.#name = self.#name.refined(refinement.#name);
184 }
185 } else {
186 quote! {
187 if let Some(value) = refinement.#name {
188 self.#name = Some(value);
189 }
190 }
191 }
192 })
193 .collect();
194
195 let from_refinement_assigments: Vec<TokenStream2> = fields
196 .iter()
197 .map(|field| {
198 let name = &field.ident;
199 let is_refineable = is_refineable_field(field);
200 let is_optional = is_optional_field(field);
201
202 if is_refineable {
203 quote! {
204 #name: value.#name.into(),
205 }
206 } else if is_optional {
207 quote! {
208 #name: value.#name.map(|v| v.into()),
209 }
210 } else {
211 quote! {
212 #name: value.#name.map(|v| v.into()).unwrap_or_default(),
213 }
214 }
215 })
216 .collect();
217
218 let debug_impl = if impl_debug_on_refinement {
219 let refinement_field_debugs: Vec<TokenStream2> = fields
220 .iter()
221 .map(|field| {
222 let name = &field.ident;
223 quote! {
224 if self.#name.is_some() {
225 debug_struct.field(stringify!(#name), &self.#name);
226 } else {
227 all_some = false;
228 }
229 }
230 })
231 .collect();
232
233 quote! {
234 impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
235 #where_clause
236 {
237 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238 let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
239 let mut all_some = true;
240 #( #refinement_field_debugs )*
241 if all_some {
242 debug_struct.finish()
243 } else {
244 debug_struct.finish_non_exhaustive()
245 }
246 }
247 }
248 }
249 } else {
250 quote! {}
251 };
252
253 let mut derive_stream = quote! {};
254 for trait_to_derive in refinement_traits_to_derive {
255 derive_stream.extend(quote! { #[derive(#trait_to_derive)] })
256 }
257
258 let gen = quote! {
259 #[derive(Clone)]
260 #derive_stream
261 pub struct #refinement_ident #impl_generics {
262 #( #field_visibilities #field_names: #wrapped_types ),*
263 }
264
265 impl #impl_generics Refineable for #ident #ty_generics
266 #where_clause
267 {
268 type Refinement = #refinement_ident #ty_generics;
269
270 fn refine(&mut self, refinement: &Self::Refinement) {
271 #( #refineable_refine_assignments )*
272 }
273
274 fn refined(mut self, refinement: Self::Refinement) -> Self {
275 #( #refineable_refined_assignments )*
276 self
277 }
278 }
279
280 impl #impl_generics Refineable for #refinement_ident #ty_generics
281 #where_clause
282 {
283 type Refinement = #refinement_ident #ty_generics;
284
285 fn refine(&mut self, refinement: &Self::Refinement) {
286 #( #refinement_refine_assigments )*
287 }
288
289 fn refined(mut self, refinement: Self::Refinement) -> Self {
290 #( #refinement_refined_assigments )*
291 self
292 }
293 }
294
295 impl #impl_generics From<#refinement_ident #ty_generics> for #ident #ty_generics
296 #where_clause
297 {
298 fn from(value: #refinement_ident #ty_generics) -> Self {
299 Self {
300 #( #from_refinement_assigments )*
301 }
302 }
303 }
304
305 impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
306 #where_clause
307 {
308 fn default() -> Self {
309 #refinement_ident {
310 #( #field_names: Default::default() ),*
311 }
312 }
313 }
314
315 impl #impl_generics #refinement_ident #ty_generics
316 #where_clause
317 {
318 pub fn is_some(&self) -> bool {
319 #(
320 if self.#field_names.is_some() {
321 return true;
322 }
323 )*
324 false
325 }
326 }
327
328 #debug_impl
329 };
330 gen.into()
331}
332
333fn is_refineable_field(f: &Field) -> bool {
334 f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
335}
336
337fn is_optional_field(f: &Field) -> bool {
338 if let Type::Path(typepath) = &f.ty {
339 if typepath.qself.is_none() {
340 let segments = &typepath.path.segments;
341 if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
342 return true;
343 }
344 }
345 }
346 false
347}
348
349fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
350 if is_refineable_field(field) {
351 let struct_name = if let Type::Path(tp) = ty {
352 tp.path.segments.last().unwrap().ident.clone()
353 } else {
354 panic!("Expected struct type for a refineable field");
355 };
356 let refinement_struct_name = format_ident!("{}Refinement", struct_name);
357 let generics = if let Type::Path(tp) = ty {
358 &tp.path.segments.last().unwrap().arguments
359 } else {
360 &syn::PathArguments::None
361 };
362 parse_quote!(#refinement_struct_name #generics)
363 } else if is_optional_field(field) {
364 ty.clone()
365 } else {
366 parse_quote!(Option<#ty>)
367 }
368}