1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{
5 DeriveInput, Field, FieldsNamed, PredicateType, TraitBound, Type, TypeParamBound, WhereClause,
6 WherePredicate, parse_macro_input, parse_quote,
7};
8
9#[proc_macro_derive(Refineable, attributes(refineable))]
10pub fn derive_refineable(input: TokenStream) -> TokenStream {
11 let DeriveInput {
12 ident,
13 data,
14 generics,
15 attrs,
16 ..
17 } = parse_macro_input!(input);
18
19 let refineable_attr = attrs.iter().find(|attr| attr.path().is_ident("refineable"));
20
21 let mut impl_debug_on_refinement = false;
22 let mut refinement_traits_to_derive = vec![];
23
24 if let Some(refineable_attr) = refineable_attr {
25 let _ = refineable_attr.parse_nested_meta(|meta| {
26 if meta.path.is_ident("Debug") {
27 impl_debug_on_refinement = true;
28 } else {
29 refinement_traits_to_derive.push(meta.path);
30 }
31 Ok(())
32 });
33 }
34
35 let refinement_ident = format_ident!("{}Refinement", ident);
36 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
37
38 let fields = match data {
39 syn::Data::Struct(syn::DataStruct {
40 fields: syn::Fields::Named(FieldsNamed { named, .. }),
41 ..
42 }) => named.into_iter().collect::<Vec<Field>>(),
43 _ => panic!("This derive macro only supports structs with named fields"),
44 };
45
46 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
47 let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
48 let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
49
50 // Create trait bound that each wrapped type must implement Clone // & Default
51 let type_param_bounds: Vec<_> = wrapped_types
52 .iter()
53 .map(|ty| {
54 WherePredicate::Type(PredicateType {
55 lifetimes: None,
56 bounded_ty: ty.clone(),
57 colon_token: Default::default(),
58 bounds: {
59 let mut punctuated = syn::punctuated::Punctuated::new();
60 punctuated.push_value(TypeParamBound::Trait(TraitBound {
61 paren_token: None,
62 modifier: syn::TraitBoundModifier::None,
63 lifetimes: None,
64 path: parse_quote!(Clone),
65 }));
66
67 punctuated
68 },
69 })
70 })
71 .collect();
72
73 // Append to where_clause or create a new one if it doesn't exist
74 let where_clause = match where_clause.cloned() {
75 Some(mut where_clause) => {
76 where_clause.predicates.extend(type_param_bounds);
77 where_clause.clone()
78 }
79 None => WhereClause {
80 where_token: Default::default(),
81 predicates: type_param_bounds.into_iter().collect(),
82 },
83 };
84
85 let refineable_refine_assignments: Vec<TokenStream2> = fields
86 .iter()
87 .map(|field| {
88 let name = &field.ident;
89 let is_refineable = is_refineable_field(field);
90 let is_optional = is_optional_field(field);
91
92 if is_refineable {
93 quote! {
94 self.#name.refine(&refinement.#name);
95 }
96 } else if is_optional {
97 quote! {
98 if let Some(value) = &refinement.#name {
99 self.#name = Some(value.clone());
100 }
101 }
102 } else {
103 quote! {
104 if let Some(value) = &refinement.#name {
105 self.#name = value.clone();
106 }
107 }
108 }
109 })
110 .collect();
111
112 let refineable_refined_assignments: Vec<TokenStream2> = fields
113 .iter()
114 .map(|field| {
115 let name = &field.ident;
116 let is_refineable = is_refineable_field(field);
117 let is_optional = is_optional_field(field);
118
119 if is_refineable {
120 quote! {
121 self.#name = self.#name.refined(refinement.#name);
122 }
123 } else if is_optional {
124 quote! {
125 if let Some(value) = refinement.#name {
126 self.#name = Some(value);
127 }
128 }
129 } else {
130 quote! {
131 if let Some(value) = refinement.#name {
132 self.#name = value;
133 }
134 }
135 }
136 })
137 .collect();
138
139 let refinement_refine_assignments: Vec<TokenStream2> = fields
140 .iter()
141 .map(|field| {
142 let name = &field.ident;
143 let is_refineable = is_refineable_field(field);
144
145 if is_refineable {
146 quote! {
147 self.#name.refine(&refinement.#name);
148 }
149 } else {
150 quote! {
151 if let Some(value) = &refinement.#name {
152 self.#name = Some(value.clone());
153 }
154 }
155 }
156 })
157 .collect();
158
159 let refinement_refined_assignments: Vec<TokenStream2> = fields
160 .iter()
161 .map(|field| {
162 let name = &field.ident;
163 let is_refineable = is_refineable_field(field);
164
165 if is_refineable {
166 quote! {
167 self.#name = self.#name.refined(refinement.#name);
168 }
169 } else {
170 quote! {
171 if let Some(value) = refinement.#name {
172 self.#name = Some(value);
173 }
174 }
175 }
176 })
177 .collect();
178
179 let from_refinement_assignments: Vec<TokenStream2> = fields
180 .iter()
181 .map(|field| {
182 let name = &field.ident;
183 let is_refineable = is_refineable_field(field);
184 let is_optional = is_optional_field(field);
185
186 if is_refineable {
187 quote! {
188 #name: value.#name.into(),
189 }
190 } else if is_optional {
191 quote! {
192 #name: value.#name.map(|v| v.into()),
193 }
194 } else {
195 quote! {
196 #name: value.#name.map(|v| v.into()).unwrap_or_default(),
197 }
198 }
199 })
200 .collect();
201
202 let debug_impl = if impl_debug_on_refinement {
203 let refinement_field_debugs: Vec<TokenStream2> = fields
204 .iter()
205 .map(|field| {
206 let name = &field.ident;
207 quote! {
208 if self.#name.is_some() {
209 debug_struct.field(stringify!(#name), &self.#name);
210 } else {
211 all_some = false;
212 }
213 }
214 })
215 .collect();
216
217 quote! {
218 impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
219 #where_clause
220 {
221 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222 let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
223 let mut all_some = true;
224 #( #refinement_field_debugs )*
225 if all_some {
226 debug_struct.finish()
227 } else {
228 debug_struct.finish_non_exhaustive()
229 }
230 }
231 }
232 }
233 } else {
234 quote! {}
235 };
236
237 let mut derive_stream = quote! {};
238 for trait_to_derive in refinement_traits_to_derive {
239 derive_stream.extend(quote! { #[derive(#trait_to_derive)] })
240 }
241
242 let r#gen = quote! {
243 /// A refinable version of [`#ident`], see that documentation for details.
244 #[derive(Clone)]
245 #derive_stream
246 pub struct #refinement_ident #impl_generics {
247 #(
248 #[allow(missing_docs)]
249 #field_visibilities #field_names: #wrapped_types
250 ),*
251 }
252
253 impl #impl_generics Refineable for #ident #ty_generics
254 #where_clause
255 {
256 type Refinement = #refinement_ident #ty_generics;
257
258 fn refine(&mut self, refinement: &Self::Refinement) {
259 #( #refineable_refine_assignments )*
260 }
261
262 fn refined(mut self, refinement: Self::Refinement) -> Self {
263 #( #refineable_refined_assignments )*
264 self
265 }
266 }
267
268 impl #impl_generics Refineable for #refinement_ident #ty_generics
269 #where_clause
270 {
271 type Refinement = #refinement_ident #ty_generics;
272
273 fn refine(&mut self, refinement: &Self::Refinement) {
274 #( #refinement_refine_assignments )*
275 }
276
277 fn refined(mut self, refinement: Self::Refinement) -> Self {
278 #( #refinement_refined_assignments )*
279 self
280 }
281 }
282
283 impl #impl_generics From<#refinement_ident #ty_generics> for #ident #ty_generics
284 #where_clause
285 {
286 fn from(value: #refinement_ident #ty_generics) -> Self {
287 Self {
288 #( #from_refinement_assignments )*
289 }
290 }
291 }
292
293 impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
294 #where_clause
295 {
296 fn default() -> Self {
297 #refinement_ident {
298 #( #field_names: Default::default() ),*
299 }
300 }
301 }
302
303 impl #impl_generics #refinement_ident #ty_generics
304 #where_clause
305 {
306 /// Returns `true` if all fields are `Some`
307 pub fn is_some(&self) -> bool {
308 #(
309 if self.#field_names.is_some() {
310 return true;
311 }
312 )*
313 false
314 }
315 }
316
317 #debug_impl
318 };
319 r#gen.into()
320}
321
322fn is_refineable_field(f: &Field) -> bool {
323 f.attrs
324 .iter()
325 .any(|attr| attr.path().is_ident("refineable"))
326}
327
328fn is_optional_field(f: &Field) -> bool {
329 if let Type::Path(typepath) = &f.ty {
330 if typepath.qself.is_none() {
331 let segments = &typepath.path.segments;
332 if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
333 return true;
334 }
335 }
336 }
337 false
338}
339
340fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
341 if is_refineable_field(field) {
342 let struct_name = if let Type::Path(tp) = ty {
343 tp.path.segments.last().unwrap().ident.clone()
344 } else {
345 panic!("Expected struct type for a refineable field");
346 };
347 let refinement_struct_name = format_ident!("{}Refinement", struct_name);
348 let generics = if let Type::Path(tp) = ty {
349 &tp.path.segments.last().unwrap().arguments
350 } else {
351 &syn::PathArguments::None
352 };
353 parse_quote!(#refinement_struct_name #generics)
354 } else if is_optional_field(field) {
355 ty.clone()
356 } else {
357 parse_quote!(Option<#ty>)
358 }
359}