1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{
5 parse_macro_input, parse_quote, DeriveInput, Field, FieldsNamed, PredicateType, TraitBound,
6 Type, TypeParamBound, WhereClause, WherePredicate,
7};
8
9#[proc_macro_derive(Refineable, attributes(refineable))]
10pub fn derive_refineable(input: TokenStream) -> TokenStream {
11 let DeriveInput {
12 ident,
13 data,
14 generics,
15 attrs,
16 ..
17 } = parse_macro_input!(input);
18
19 let refineable_attr = attrs.iter().find(|attr| attr.path.is_ident("refineable"));
20
21 let mut impl_debug_on_refinement = false;
22 let mut derive_serialize_on_refinement = false;
23 let mut derive_deserialize_on_refinement = false;
24
25 if let Some(refineable_attr) = refineable_attr {
26 if let Ok(syn::Meta::List(meta_list)) = refineable_attr.parse_meta() {
27 for nested in meta_list.nested {
28 let syn::NestedMeta::Meta(syn::Meta::Path(path)) = nested else {
29 continue;
30 };
31
32 if path.is_ident("debug") {
33 impl_debug_on_refinement = true;
34 }
35
36 if path.is_ident("serialize") {
37 derive_serialize_on_refinement = true;
38 }
39
40 if path.is_ident("deserialize") {
41 derive_deserialize_on_refinement = true;
42 }
43 }
44 }
45 }
46
47 let refinement_ident = format_ident!("{}Refinement", ident);
48 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
49
50 let fields = match data {
51 syn::Data::Struct(syn::DataStruct {
52 fields: syn::Fields::Named(FieldsNamed { named, .. }),
53 ..
54 }) => named.into_iter().collect::<Vec<Field>>(),
55 _ => panic!("This derive macro only supports structs with named fields"),
56 };
57
58 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
59 let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
60 let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
61
62 // Create trait bound that each wrapped type must implement Clone // & Default
63 let type_param_bounds: Vec<_> = wrapped_types
64 .iter()
65 .map(|ty| {
66 WherePredicate::Type(PredicateType {
67 lifetimes: None,
68 bounded_ty: ty.clone(),
69 colon_token: Default::default(),
70 bounds: {
71 let mut punctuated = syn::punctuated::Punctuated::new();
72 punctuated.push_value(TypeParamBound::Trait(TraitBound {
73 paren_token: None,
74 modifier: syn::TraitBoundModifier::None,
75 lifetimes: None,
76 path: parse_quote!(Clone),
77 }));
78
79 // punctuated.push_punct(syn::token::Add::default());
80 // punctuated.push_value(TypeParamBound::Trait(TraitBound {
81 // paren_token: None,
82 // modifier: syn::TraitBoundModifier::None,
83 // lifetimes: None,
84 // path: parse_quote!(Default),
85 // }));
86 punctuated
87 },
88 })
89 })
90 .collect();
91
92 // Append to where_clause or create a new one if it doesn't exist
93 let where_clause = match where_clause.cloned() {
94 Some(mut where_clause) => {
95 where_clause
96 .predicates
97 .extend(type_param_bounds.into_iter());
98 where_clause.clone()
99 }
100 None => WhereClause {
101 where_token: Default::default(),
102 predicates: type_param_bounds.into_iter().collect(),
103 },
104 };
105
106 // refinable_refine_assignments
107 // refinable_refined_assignments
108 // refinement_refine_assignments
109
110 let refineable_refine_assignments: Vec<TokenStream2> = fields
111 .iter()
112 .map(|field| {
113 let name = &field.ident;
114 let is_refineable = is_refineable_field(field);
115 let is_optional = is_optional_field(field);
116
117 if is_refineable {
118 quote! {
119 self.#name.refine(&refinement.#name);
120 }
121 } else if is_optional {
122 quote! {
123 if let Some(ref value) = &refinement.#name {
124 self.#name = Some(value.clone());
125 }
126 }
127 } else {
128 quote! {
129 if let Some(ref value) = &refinement.#name {
130 self.#name = value.clone();
131 }
132 }
133 }
134 })
135 .collect();
136
137 let refineable_refined_assignments: Vec<TokenStream2> = fields
138 .iter()
139 .map(|field| {
140 let name = &field.ident;
141 let is_refineable = is_refineable_field(field);
142 let is_optional = is_optional_field(field);
143
144 if is_refineable {
145 quote! {
146 self.#name = self.#name.refined(refinement.#name);
147 }
148 } else if is_optional {
149 quote! {
150 if let Some(value) = refinement.#name {
151 self.#name = Some(value);
152 }
153 }
154 } else {
155 quote! {
156 if let Some(value) = refinement.#name {
157 self.#name = value;
158 }
159 }
160 }
161 })
162 .collect();
163
164 let refinement_refine_assigments: Vec<TokenStream2> = fields
165 .iter()
166 .map(|field| {
167 let name = &field.ident;
168 let is_refineable = is_refineable_field(field);
169
170 if is_refineable {
171 quote! {
172 self.#name.refine(&refinement.#name);
173 }
174 } else {
175 quote! {
176 if let Some(ref value) = &refinement.#name {
177 self.#name = Some(value.clone());
178 }
179 }
180 }
181 })
182 .collect();
183
184 let refinement_refined_assigments: Vec<TokenStream2> = fields
185 .iter()
186 .map(|field| {
187 let name = &field.ident;
188 let is_refineable = is_refineable_field(field);
189
190 if is_refineable {
191 quote! {
192 self.#name = self.#name.refined(refinement.#name);
193 }
194 } else {
195 quote! {
196 if let Some(value) = refinement.#name {
197 self.#name = Some(value);
198 }
199 }
200 }
201 })
202 .collect();
203
204 let from_refinement_assigments: Vec<TokenStream2> = fields
205 .iter()
206 .map(|field| {
207 let name = &field.ident;
208 let is_refineable = is_refineable_field(field);
209 let is_optional = is_optional_field(field);
210
211 if is_refineable {
212 quote! {
213 #name: value.#name.into(),
214 }
215 } else if is_optional {
216 quote! {
217 #name: value.#name.map(|v| v.into()),
218 }
219 } else {
220 quote! {
221 #name: value.#name.map(|v| v.into()).unwrap_or_default(),
222 }
223 }
224 })
225 .collect();
226
227 let debug_impl = if impl_debug_on_refinement {
228 let refinement_field_debugs: Vec<TokenStream2> = fields
229 .iter()
230 .map(|field| {
231 let name = &field.ident;
232 quote! {
233 if self.#name.is_some() {
234 debug_struct.field(stringify!(#name), &self.#name);
235 } else {
236 all_some = false;
237 }
238 }
239 })
240 .collect();
241
242 quote! {
243 impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
244 #where_clause
245 {
246 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247 let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
248 let mut all_some = true;
249 #( #refinement_field_debugs )*
250 if all_some {
251 debug_struct.finish()
252 } else {
253 debug_struct.finish_non_exhaustive()
254 }
255 }
256 }
257 }
258 } else {
259 quote! {}
260 };
261
262 let derive_serialize = if derive_serialize_on_refinement {
263 quote! { #[derive(serde::Serialize)]}
264 } else {
265 quote! {}
266 };
267
268 let derive_deserialize = if derive_deserialize_on_refinement {
269 quote! { #[derive(serde::Deserialize)]}
270 } else {
271 quote! {}
272 };
273
274 let gen = quote! {
275 #[derive(Clone)]
276 #derive_serialize
277 #derive_deserialize
278 pub struct #refinement_ident #impl_generics {
279 #( #field_visibilities #field_names: #wrapped_types ),*
280 }
281
282 impl #impl_generics Refineable for #ident #ty_generics
283 #where_clause
284 {
285 type Refinement = #refinement_ident #ty_generics;
286
287 fn refine(&mut self, refinement: &Self::Refinement) {
288 #( #refineable_refine_assignments )*
289 }
290
291 fn refined(mut self, refinement: Self::Refinement) -> Self {
292 #( #refineable_refined_assignments )*
293 self
294 }
295 }
296
297 impl #impl_generics Refineable for #refinement_ident #ty_generics
298 #where_clause
299 {
300 type Refinement = #refinement_ident #ty_generics;
301
302 fn refine(&mut self, refinement: &Self::Refinement) {
303 #( #refinement_refine_assigments )*
304 }
305
306 fn refined(mut self, refinement: Self::Refinement) -> Self {
307 #( #refinement_refined_assigments )*
308 self
309 }
310 }
311
312 impl #impl_generics From<#refinement_ident #ty_generics> for #ident #ty_generics
313 #where_clause
314 {
315 fn from(value: #refinement_ident #ty_generics) -> Self {
316 Self {
317 #( #from_refinement_assigments )*
318 }
319 }
320 }
321
322 impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
323 #where_clause
324 {
325 fn default() -> Self {
326 #refinement_ident {
327 #( #field_names: Default::default() ),*
328 }
329 }
330 }
331
332 impl #impl_generics #refinement_ident #ty_generics
333 #where_clause
334 {
335 pub fn is_some(&self) -> bool {
336 #(
337 if self.#field_names.is_some() {
338 return true;
339 }
340 )*
341 false
342 }
343 }
344
345 #debug_impl
346 };
347 gen.into()
348}
349
350fn is_refineable_field(f: &Field) -> bool {
351 f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
352}
353
354fn is_optional_field(f: &Field) -> bool {
355 if let Type::Path(typepath) = &f.ty {
356 if typepath.qself.is_none() {
357 let segments = &typepath.path.segments;
358 if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
359 return true;
360 }
361 }
362 }
363 false
364}
365
366fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
367 if is_refineable_field(field) {
368 let struct_name = if let Type::Path(tp) = ty {
369 tp.path.segments.last().unwrap().ident.clone()
370 } else {
371 panic!("Expected struct type for a refineable field");
372 };
373 let refinement_struct_name = format_ident!("{}Refinement", struct_name);
374 let generics = if let Type::Path(tp) = ty {
375 &tp.path.segments.last().unwrap().arguments
376 } else {
377 &syn::PathArguments::None
378 };
379 parse_quote!(#refinement_struct_name #generics)
380 } else if is_optional_field(field) {
381 ty.clone()
382 } else {
383 parse_quote!(Option<#ty>)
384 }
385}