1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{
5 DeriveInput, Field, FieldsNamed, PredicateType, TraitBound, Type, TypeParamBound, WhereClause,
6 WherePredicate, parse_macro_input, parse_quote,
7};
8
9#[proc_macro_derive(Refineable, attributes(refineable))]
10pub fn derive_refineable(input: TokenStream) -> TokenStream {
11 let DeriveInput {
12 ident,
13 data,
14 generics,
15 attrs,
16 ..
17 } = parse_macro_input!(input);
18
19 let refineable_attr = attrs.iter().find(|attr| attr.path().is_ident("refineable"));
20
21 let mut impl_debug_on_refinement = false;
22 let mut derives_serialize = false;
23 let mut refinement_traits_to_derive = vec![];
24
25 if let Some(refineable_attr) = refineable_attr {
26 let _ = refineable_attr.parse_nested_meta(|meta| {
27 if meta.path.is_ident("Debug") {
28 impl_debug_on_refinement = true;
29 } else {
30 if meta.path.is_ident("Serialize") {
31 derives_serialize = true;
32 }
33 refinement_traits_to_derive.push(meta.path);
34 }
35 Ok(())
36 });
37 }
38
39 let refinement_ident = format_ident!("{}Refinement", ident);
40 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
41
42 let fields = match data {
43 syn::Data::Struct(syn::DataStruct {
44 fields: syn::Fields::Named(FieldsNamed { named, .. }),
45 ..
46 }) => named.into_iter().collect::<Vec<Field>>(),
47 _ => panic!("This derive macro only supports structs with named fields"),
48 };
49
50 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
51 let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
52 let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
53
54 let field_attributes: Vec<TokenStream2> = fields
55 .iter()
56 .map(|f| {
57 if derives_serialize {
58 if is_refineable_field(f) {
59 quote! { #[serde(default, skip_serializing_if = "::refineable::IsEmpty::is_empty")] }
60 } else {
61 quote! { #[serde(skip_serializing_if = "::std::option::Option::is_none")] }
62 }
63 } else {
64 quote! {}
65 }
66 })
67 .collect();
68
69 // Create trait bound that each wrapped type must implement Clone
70 let type_param_bounds: Vec<_> = wrapped_types
71 .iter()
72 .map(|ty| {
73 WherePredicate::Type(PredicateType {
74 lifetimes: None,
75 bounded_ty: ty.clone(),
76 colon_token: Default::default(),
77 bounds: {
78 let mut punctuated = syn::punctuated::Punctuated::new();
79 punctuated.push_value(TypeParamBound::Trait(TraitBound {
80 paren_token: None,
81 modifier: syn::TraitBoundModifier::None,
82 lifetimes: None,
83 path: parse_quote!(Clone),
84 }));
85
86 punctuated
87 },
88 })
89 })
90 .collect();
91
92 // Append to where_clause or create a new one if it doesn't exist
93 let where_clause = match where_clause.cloned() {
94 Some(mut where_clause) => {
95 where_clause.predicates.extend(type_param_bounds);
96 where_clause.clone()
97 }
98 None => WhereClause {
99 where_token: Default::default(),
100 predicates: type_param_bounds.into_iter().collect(),
101 },
102 };
103
104 let refineable_refine_assignments: Vec<TokenStream2> = fields
105 .iter()
106 .map(|field| {
107 let name = &field.ident;
108 let is_refineable = is_refineable_field(field);
109 let is_optional = is_optional_field(field);
110
111 if is_refineable {
112 quote! {
113 self.#name.refine(&refinement.#name);
114 }
115 } else if is_optional {
116 quote! {
117 if let Some(value) = &refinement.#name {
118 self.#name = Some(value.clone());
119 }
120 }
121 } else {
122 quote! {
123 if let Some(value) = &refinement.#name {
124 self.#name = value.clone();
125 }
126 }
127 }
128 })
129 .collect();
130
131 let refineable_refined_assignments: Vec<TokenStream2> = fields
132 .iter()
133 .map(|field| {
134 let name = &field.ident;
135 let is_refineable = is_refineable_field(field);
136 let is_optional = is_optional_field(field);
137
138 if is_refineable {
139 quote! {
140 self.#name = self.#name.refined(refinement.#name);
141 }
142 } else if is_optional {
143 quote! {
144 if let Some(value) = refinement.#name {
145 self.#name = Some(value);
146 }
147 }
148 } else {
149 quote! {
150 if let Some(value) = refinement.#name {
151 self.#name = value;
152 }
153 }
154 }
155 })
156 .collect();
157
158 let refinement_refine_assignments: Vec<TokenStream2> = fields
159 .iter()
160 .map(|field| {
161 let name = &field.ident;
162 let is_refineable = is_refineable_field(field);
163
164 if is_refineable {
165 quote! {
166 self.#name.refine(&refinement.#name);
167 }
168 } else {
169 quote! {
170 if let Some(value) = &refinement.#name {
171 self.#name = Some(value.clone());
172 }
173 }
174 }
175 })
176 .collect();
177
178 let refinement_refined_assignments: Vec<TokenStream2> = fields
179 .iter()
180 .map(|field| {
181 let name = &field.ident;
182 let is_refineable = is_refineable_field(field);
183
184 if is_refineable {
185 quote! {
186 self.#name = self.#name.refined(refinement.#name);
187 }
188 } else {
189 quote! {
190 if let Some(value) = refinement.#name {
191 self.#name = Some(value);
192 }
193 }
194 }
195 })
196 .collect();
197
198 let from_refinement_assignments: Vec<TokenStream2> = fields
199 .iter()
200 .map(|field| {
201 let name = &field.ident;
202 let is_refineable = is_refineable_field(field);
203 let is_optional = is_optional_field(field);
204
205 if is_refineable {
206 quote! {
207 #name: value.#name.into(),
208 }
209 } else if is_optional {
210 quote! {
211 #name: value.#name.map(|v| v.into()),
212 }
213 } else {
214 quote! {
215 #name: value.#name.map(|v| v.into()).unwrap_or_default(),
216 }
217 }
218 })
219 .collect();
220
221 let debug_impl = if impl_debug_on_refinement {
222 let refinement_field_debugs: Vec<TokenStream2> = fields
223 .iter()
224 .map(|field| {
225 let name = &field.ident;
226 quote! {
227 if self.#name.is_some() {
228 debug_struct.field(stringify!(#name), &self.#name);
229 } else {
230 all_some = false;
231 }
232 }
233 })
234 .collect();
235
236 quote! {
237 impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
238 #where_clause
239 {
240 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241 let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
242 let mut all_some = true;
243 #( #refinement_field_debugs )*
244 if all_some {
245 debug_struct.finish()
246 } else {
247 debug_struct.finish_non_exhaustive()
248 }
249 }
250 }
251 }
252 } else {
253 quote! {}
254 };
255
256 let refinement_is_empty_conditions: Vec<TokenStream2> = fields
257 .iter()
258 .enumerate()
259 .map(|(i, field)| {
260 let name = &field.ident;
261
262 let condition = if is_refineable_field(field) {
263 quote! { self.#name.is_empty() }
264 } else {
265 quote! { self.#name.is_none() }
266 };
267
268 if i < fields.len() - 1 {
269 quote! { #condition && }
270 } else {
271 condition
272 }
273 })
274 .collect();
275
276 let refineable_is_superset_conditions: Vec<TokenStream2> = fields
277 .iter()
278 .map(|field| {
279 let name = &field.ident;
280 let is_refineable = is_refineable_field(field);
281 let is_optional = is_optional_field(field);
282
283 if is_refineable {
284 quote! {
285 if !self.#name.is_superset_of(&refinement.#name) {
286 return false;
287 }
288 }
289 } else if is_optional {
290 quote! {
291 if refinement.#name.is_some() && &self.#name != &refinement.#name {
292 return false;
293 }
294 }
295 } else {
296 quote! {
297 if let Some(refinement_value) = &refinement.#name {
298 if &self.#name != refinement_value {
299 return false;
300 }
301 }
302 }
303 }
304 })
305 .collect();
306
307 let refinement_is_superset_conditions: Vec<TokenStream2> = fields
308 .iter()
309 .map(|field| {
310 let name = &field.ident;
311 let is_refineable = is_refineable_field(field);
312
313 if is_refineable {
314 quote! {
315 if !self.#name.is_superset_of(&refinement.#name) {
316 return false;
317 }
318 }
319 } else {
320 quote! {
321 if refinement.#name.is_some() && &self.#name != &refinement.#name {
322 return false;
323 }
324 }
325 }
326 })
327 .collect();
328
329 let refineable_subtract_assignments: Vec<TokenStream2> = fields
330 .iter()
331 .map(|field| {
332 let name = &field.ident;
333 let is_refineable = is_refineable_field(field);
334 let is_optional = is_optional_field(field);
335
336 if is_refineable {
337 quote! {
338 #name: self.#name.subtract(&refinement.#name),
339 }
340 } else if is_optional {
341 quote! {
342 #name: if &self.#name == &refinement.#name {
343 None
344 } else {
345 self.#name.clone()
346 },
347 }
348 } else {
349 quote! {
350 #name: if let Some(refinement_value) = &refinement.#name {
351 if &self.#name == refinement_value {
352 None
353 } else {
354 Some(self.#name.clone())
355 }
356 } else {
357 Some(self.#name.clone())
358 },
359 }
360 }
361 })
362 .collect();
363
364 let refinement_subtract_assignments: Vec<TokenStream2> = fields
365 .iter()
366 .map(|field| {
367 let name = &field.ident;
368 let is_refineable = is_refineable_field(field);
369
370 if is_refineable {
371 quote! {
372 #name: self.#name.subtract(&refinement.#name),
373 }
374 } else {
375 quote! {
376 #name: if &self.#name == &refinement.#name {
377 None
378 } else {
379 self.#name.clone()
380 },
381 }
382 }
383 })
384 .collect();
385
386 let mut derive_stream = quote! {};
387 for trait_to_derive in refinement_traits_to_derive {
388 derive_stream.extend(quote! { #[derive(#trait_to_derive)] })
389 }
390
391 let r#gen = quote! {
392 /// A refinable version of [`#ident`], see that documentation for details.
393 #[derive(Clone)]
394 #derive_stream
395 pub struct #refinement_ident #impl_generics {
396 #(
397 #[allow(missing_docs)]
398 #field_attributes
399 #field_visibilities #field_names: #wrapped_types
400 ),*
401 }
402
403 impl #impl_generics Refineable for #ident #ty_generics
404 #where_clause
405 {
406 type Refinement = #refinement_ident #ty_generics;
407
408 fn refine(&mut self, refinement: &Self::Refinement) {
409 #( #refineable_refine_assignments )*
410 }
411
412 fn refined(mut self, refinement: Self::Refinement) -> Self {
413 #( #refineable_refined_assignments )*
414 self
415 }
416
417 fn is_superset_of(&self, refinement: &Self::Refinement) -> bool
418 {
419 #( #refineable_is_superset_conditions )*
420 true
421 }
422
423 fn subtract(&self, refinement: &Self::Refinement) -> Self::Refinement
424 {
425 #refinement_ident {
426 #( #refineable_subtract_assignments )*
427 }
428 }
429 }
430
431 impl #impl_generics Refineable for #refinement_ident #ty_generics
432 #where_clause
433 {
434 type Refinement = #refinement_ident #ty_generics;
435
436 fn refine(&mut self, refinement: &Self::Refinement) {
437 #( #refinement_refine_assignments )*
438 }
439
440 fn refined(mut self, refinement: Self::Refinement) -> Self {
441 #( #refinement_refined_assignments )*
442 self
443 }
444
445 fn is_superset_of(&self, refinement: &Self::Refinement) -> bool
446 {
447 #( #refinement_is_superset_conditions )*
448 true
449 }
450
451 fn subtract(&self, refinement: &Self::Refinement) -> Self::Refinement
452 {
453 #refinement_ident {
454 #( #refinement_subtract_assignments )*
455 }
456 }
457 }
458
459 impl #impl_generics ::refineable::IsEmpty for #refinement_ident #ty_generics
460 #where_clause
461 {
462 fn is_empty(&self) -> bool {
463 #( #refinement_is_empty_conditions )*
464 }
465 }
466
467 impl #impl_generics From<#refinement_ident #ty_generics> for #ident #ty_generics
468 #where_clause
469 {
470 fn from(value: #refinement_ident #ty_generics) -> Self {
471 Self {
472 #( #from_refinement_assignments )*
473 }
474 }
475 }
476
477 impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
478 #where_clause
479 {
480 fn default() -> Self {
481 #refinement_ident {
482 #( #field_names: Default::default() ),*
483 }
484 }
485 }
486
487 impl #impl_generics #refinement_ident #ty_generics
488 #where_clause
489 {
490 /// Returns `true` if all fields are `Some`
491 pub fn is_some(&self) -> bool {
492 #(
493 if self.#field_names.is_some() {
494 return true;
495 }
496 )*
497 false
498 }
499 }
500
501 #debug_impl
502 };
503 r#gen.into()
504}
505
506fn is_refineable_field(f: &Field) -> bool {
507 f.attrs
508 .iter()
509 .any(|attr| attr.path().is_ident("refineable"))
510}
511
512fn is_optional_field(f: &Field) -> bool {
513 if let Type::Path(typepath) = &f.ty
514 && typepath.qself.is_none()
515 {
516 let segments = &typepath.path.segments;
517 if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
518 return true;
519 }
520 }
521 false
522}
523
524fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
525 if is_refineable_field(field) {
526 let struct_name = if let Type::Path(tp) = ty {
527 tp.path.segments.last().unwrap().ident.clone()
528 } else {
529 panic!("Expected struct type for a refineable field");
530 };
531 let refinement_struct_name = format_ident!("{}Refinement", struct_name);
532 let generics = if let Type::Path(tp) = ty {
533 &tp.path.segments.last().unwrap().arguments
534 } else {
535 &syn::PathArguments::None
536 };
537 parse_quote!(#refinement_struct_name #generics)
538 } else if is_optional_field(field) {
539 ty.clone()
540 } else {
541 parse_quote!(Option<#ty>)
542 }
543}