1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use quote::quote;
4use syn::{Data, DeriveInput, Fields, Ident, LitStr, parse_macro_input};
5
6/// Derives [`feature_flags::FeatureFlagValue`] for a unit-only enum.
7///
8/// Exactly one variant must be marked with `#[default]`. The default variant
9/// is the one returned when the feature flag is announced by the server,
10/// enabled for all users, or enabled by the staff rule — it's the "on"
11/// value, and also the fallback for `from_wire`.
12///
13/// The generated impl derives:
14///
15/// * `all_variants` — every variant, in source order.
16/// * `override_key` — the variant name, lower-cased with dashes between
17/// PascalCase word boundaries (e.g. `NewWorktree` → `"new-worktree"`).
18/// * `label` — the variant name with PascalCase boundaries expanded to
19/// spaces (e.g. `NewWorktree` → `"New Worktree"`).
20/// * `from_wire` — always returns the default variant, since today the
21/// server wire format is just presence and does not carry a variant.
22///
23/// ## Example
24///
25/// ```ignore
26/// #[derive(Clone, Copy, PartialEq, Eq, Debug, EnumFeatureFlag)]
27/// enum Intensity {
28/// #[default]
29/// Low,
30/// High,
31/// }
32/// ```
33// `attributes(default)` lets users write `#[default]` on a variant even when
34// they're not also deriving `Default`. If `#[derive(Default)]` is present in
35// the same list, it reuses the same attribute — there's no conflict, because
36// helper attributes aren't consumed.
37#[proc_macro_derive(EnumFeatureFlag, attributes(default))]
38pub fn derive_enum_feature_flag(input: TokenStream) -> TokenStream {
39 let input = parse_macro_input!(input as DeriveInput);
40 match expand(&input) {
41 Ok(tokens) => tokens.into(),
42 Err(e) => e.to_compile_error().into(),
43 }
44}
45
46fn expand(input: &DeriveInput) -> syn::Result<TokenStream2> {
47 let Data::Enum(data) = &input.data else {
48 return Err(syn::Error::new_spanned(
49 input,
50 "EnumFeatureFlag can only be derived for enums",
51 ));
52 };
53
54 if data.variants.is_empty() {
55 return Err(syn::Error::new_spanned(
56 input,
57 "EnumFeatureFlag requires at least one variant",
58 ));
59 }
60
61 let mut default_ident: Option<&Ident> = None;
62 let mut variant_idents: Vec<&Ident> = Vec::new();
63
64 for variant in &data.variants {
65 if !matches!(variant.fields, Fields::Unit) {
66 return Err(syn::Error::new_spanned(
67 variant,
68 "EnumFeatureFlag only supports unit variants (no fields)",
69 ));
70 }
71 if has_default_attr(variant) {
72 if default_ident.is_some() {
73 return Err(syn::Error::new_spanned(
74 variant,
75 "only one variant may be marked with #[default]",
76 ));
77 }
78 default_ident = Some(&variant.ident);
79 }
80 variant_idents.push(&variant.ident);
81 }
82
83 let Some(default_ident) = default_ident else {
84 return Err(syn::Error::new_spanned(
85 input,
86 "EnumFeatureFlag requires exactly one variant to be marked with #[default]",
87 ));
88 };
89
90 let name = &input.ident;
91 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
92
93 let override_key_arms = variant_idents.iter().map(|variant| {
94 let key = LitStr::new(&to_kebab_case(&variant.to_string()), Span::call_site());
95 quote! { #name::#variant => #key }
96 });
97
98 let label_arms = variant_idents.iter().map(|variant| {
99 let label = LitStr::new(&to_space_separated(&variant.to_string()), Span::call_site());
100 quote! { #name::#variant => #label }
101 });
102
103 let all_variants = variant_idents.iter().map(|v| quote! { #name::#v });
104
105 Ok(quote! {
106 impl #impl_generics ::std::default::Default for #name #ty_generics #where_clause {
107 fn default() -> Self {
108 #name::#default_ident
109 }
110 }
111
112 impl #impl_generics ::feature_flags::FeatureFlagValue for #name #ty_generics #where_clause {
113 fn all_variants() -> &'static [Self] {
114 &[ #( #all_variants ),* ]
115 }
116
117 fn override_key(&self) -> &'static str {
118 match self {
119 #( #override_key_arms ),*
120 }
121 }
122
123 fn label(&self) -> &'static str {
124 match self {
125 #( #label_arms ),*
126 }
127 }
128
129 fn from_wire(_: &str) -> ::std::option::Option<Self> {
130 ::std::option::Option::Some(#name::#default_ident)
131 }
132 }
133 })
134}
135
136fn has_default_attr(variant: &syn::Variant) -> bool {
137 variant.attrs.iter().any(|a| a.path().is_ident("default"))
138}
139
/// Converts a PascalCase identifier to lowercase kebab-case.
///
/// A dash is inserted before every ASCII uppercase letter except the first
/// character, and each uppercase letter is lowered; every other character is
/// copied through unchanged.
///
/// `"NewWorktree"` → `"new-worktree"`, `"Low"` → `"low"`.
///
/// Acronyms are split letter by letter: `"HTTPServer"` → `"h-t-t-p-server"`.
/// Spell acronyms as a single word, per Rust naming conventions
/// (`"HttpServer"` → `"http-server"`), to get readable keys. The result
/// becomes the generated `override_key`, so changing this algorithm would
/// change the key of every existing variant.
fn to_kebab_case(ident: &str) -> String {
    let mut out = String::with_capacity(ident.len() + 4);
    for (i, ch) in ident.chars().enumerate() {
        if ch.is_ascii_uppercase() {
            // Every uppercase letter after the first starts a new segment.
            if i != 0 {
                out.push('-');
            }
            out.push(ch.to_ascii_lowercase());
        } else {
            out.push(ch);
        }
    }
    out
}
159
/// Expands a PascalCase identifier into space-separated words for display.
///
/// A single space goes in front of every ASCII uppercase letter that is not
/// the leading character; letter case is left untouched.
///
/// `"NewWorktree"` → `"New Worktree"`, `"Low"` → `"Low"`,
/// `"HTTPServer"` → `"H T T P Server"`.
fn to_space_separated(ident: &str) -> String {
    let mut words = String::with_capacity(ident.len() + 4);
    let mut rest = ident.chars();

    // The leading character never gets a separator, even when uppercase.
    if let Some(first) = rest.next() {
        words.push(first);
    }

    for c in rest {
        if c.is_ascii_uppercase() {
            words.push(' ');
        }
        words.push(c);
    }

    words
}
173
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn kebab_case() {
        assert_eq!(to_kebab_case("Low"), "low");
        assert_eq!(to_kebab_case("NewWorktree"), "new-worktree");
        assert_eq!(to_kebab_case("A"), "a");
    }

    #[test]
    fn kebab_case_splits_every_uppercase_letter() {
        // Acronyms are not kept together: each capital starts a new segment.
        // Pinned because the generated `override_key` depends on it.
        assert_eq!(to_kebab_case("HTTPServer"), "h-t-t-p-server");
        assert_eq!(to_kebab_case("HttpServer"), "http-server");
        // Non-uppercase characters (digits included) pass through untouched.
        assert_eq!(to_kebab_case("Level2"), "level2");
    }

    #[test]
    fn space_separated() {
        assert_eq!(to_space_separated("Low"), "Low");
        assert_eq!(to_space_separated("NewWorktree"), "New Worktree");
        assert_eq!(to_space_separated("A"), "A");
    }

    #[test]
    fn space_separated_splits_every_uppercase_letter() {
        assert_eq!(to_space_separated("HTTPServer"), "H T T P Server");
        assert_eq!(to_space_separated("HttpServer"), "Http Server");
    }
}