feature_flags_macros.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::{Span, TokenStream as TokenStream2};
  3use quote::quote;
  4use syn::{Data, DeriveInput, Fields, Ident, LitStr, parse_macro_input};
  5
  6/// Derives [`feature_flags::FeatureFlagValue`] for a unit-only enum.
  7///
  8/// Exactly one variant must be marked with `#[default]`. The default variant
  9/// is the one returned when the feature flag is announced by the server,
 10/// enabled for all users, or enabled by the staff rule — it's the "on"
 11/// value, and also the fallback for `from_wire`.
 12///
 13/// The generated impl derives:
 14///
 15/// * `all_variants` — every variant, in source order.
 16/// * `override_key` — the variant name, lower-cased with dashes between
 17///   PascalCase word boundaries (e.g. `NewWorktree` → `"new-worktree"`).
 18/// * `label` — the variant name with PascalCase boundaries expanded to
 19///   spaces (e.g. `NewWorktree` → `"New Worktree"`).
 20/// * `from_wire` — always returns the default variant, since today the
 21///   server wire format is just presence and does not carry a variant.
 22///
 23/// ## Example
 24///
 25/// ```ignore
 26/// #[derive(Clone, Copy, PartialEq, Eq, Debug, EnumFeatureFlag)]
 27/// enum Intensity {
 28///     #[default]
 29///     Low,
 30///     High,
 31/// }
 32/// ```
 33// `attributes(default)` lets users write `#[default]` on a variant even when
 34// they're not also deriving `Default`. If `#[derive(Default)]` is present in
 35// the same list, it reuses the same attribute — there's no conflict, because
 36// helper attributes aren't consumed.
 37#[proc_macro_derive(EnumFeatureFlag, attributes(default))]
 38pub fn derive_enum_feature_flag(input: TokenStream) -> TokenStream {
 39    let input = parse_macro_input!(input as DeriveInput);
 40    match expand(&input) {
 41        Ok(tokens) => tokens.into(),
 42        Err(e) => e.to_compile_error().into(),
 43    }
 44}
 45
 46fn expand(input: &DeriveInput) -> syn::Result<TokenStream2> {
 47    let Data::Enum(data) = &input.data else {
 48        return Err(syn::Error::new_spanned(
 49            input,
 50            "EnumFeatureFlag can only be derived for enums",
 51        ));
 52    };
 53
 54    if data.variants.is_empty() {
 55        return Err(syn::Error::new_spanned(
 56            input,
 57            "EnumFeatureFlag requires at least one variant",
 58        ));
 59    }
 60
 61    let mut default_ident: Option<&Ident> = None;
 62    let mut variant_idents: Vec<&Ident> = Vec::new();
 63
 64    for variant in &data.variants {
 65        if !matches!(variant.fields, Fields::Unit) {
 66            return Err(syn::Error::new_spanned(
 67                variant,
 68                "EnumFeatureFlag only supports unit variants (no fields)",
 69            ));
 70        }
 71        if has_default_attr(variant) {
 72            if default_ident.is_some() {
 73                return Err(syn::Error::new_spanned(
 74                    variant,
 75                    "only one variant may be marked with #[default]",
 76                ));
 77            }
 78            default_ident = Some(&variant.ident);
 79        }
 80        variant_idents.push(&variant.ident);
 81    }
 82
 83    let Some(default_ident) = default_ident else {
 84        return Err(syn::Error::new_spanned(
 85            input,
 86            "EnumFeatureFlag requires exactly one variant to be marked with #[default]",
 87        ));
 88    };
 89
 90    let name = &input.ident;
 91    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
 92
 93    let override_key_arms = variant_idents.iter().map(|variant| {
 94        let key = LitStr::new(&to_kebab_case(&variant.to_string()), Span::call_site());
 95        quote! { #name::#variant => #key }
 96    });
 97
 98    let label_arms = variant_idents.iter().map(|variant| {
 99        let label = LitStr::new(&to_space_separated(&variant.to_string()), Span::call_site());
100        quote! { #name::#variant => #label }
101    });
102
103    let all_variants = variant_idents.iter().map(|v| quote! { #name::#v });
104
105    Ok(quote! {
106        impl #impl_generics ::std::default::Default for #name #ty_generics #where_clause {
107            fn default() -> Self {
108                #name::#default_ident
109            }
110        }
111
112        impl #impl_generics ::feature_flags::FeatureFlagValue for #name #ty_generics #where_clause {
113            fn all_variants() -> &'static [Self] {
114                &[ #( #all_variants ),* ]
115            }
116
117            fn override_key(&self) -> &'static str {
118                match self {
119                    #( #override_key_arms ),*
120                }
121            }
122
123            fn label(&self) -> &'static str {
124                match self {
125                    #( #label_arms ),*
126                }
127            }
128
129            fn from_wire(_: &str) -> ::std::option::Option<Self> {
130                ::std::option::Option::Some(#name::#default_ident)
131            }
132        }
133    })
134}
135
136fn has_default_attr(variant: &syn::Variant) -> bool {
137    variant.attrs.iter().any(|a| a.path().is_ident("default"))
138}
139
/// Converts a PascalCase identifier to lowercase kebab-case.
///
/// Every uppercase letter after the first character starts a new word and is
/// preceded by `-`. All uppercase letters are lower-cased. Other characters
/// (lowercase letters, digits, `_`) are copied unchanged.
///
/// Examples: `"NewWorktree"` → `"new-worktree"`, `"Low"` → `"low"`.
/// Consecutive capitals are split letter by letter: `"HTTPServer"` →
/// `"h-t-t-p-server"`. Name variants `HttpServer` to get `"http-server"`.
///
/// Uppercase detection and lower-casing are Unicode-aware, so the result is
/// lower-case for non-ASCII identifiers too (`"ÜberGröße"` → `"über-größe"`).
fn to_kebab_case(ident: &str) -> String {
    let mut out = String::with_capacity(ident.len() + 4);
    for (i, ch) in ident.chars().enumerate() {
        if ch.is_uppercase() {
            if i != 0 {
                out.push('-');
            }
            // `to_lowercase` may yield more than one char for some letters.
            out.extend(ch.to_lowercase());
        } else {
            out.push(ch);
        }
    }
    out
}
159
/// Converts a PascalCase identifier to space-separated words for display.
///
/// A space is inserted before every uppercase letter except the first
/// character. Letter case is otherwise preserved.
///
/// Examples: `"NewWorktree"` → `"New Worktree"`, `"Low"` → `"Low"`,
/// `"HTTPServer"` → `"H T T P Server"`.
///
/// Uppercase detection is Unicode-aware (`"ÜberGröße"` → `"Über Größe"`).
fn to_space_separated(ident: &str) -> String {
    let mut out = String::with_capacity(ident.len() + 4);
    for (i, ch) in ident.chars().enumerate() {
        if i != 0 && ch.is_uppercase() {
            out.push(' ');
        }
        out.push(ch);
    }
    out
}
173
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn kebab_case() {
        assert_eq!(to_kebab_case("Low"), "low");
        assert_eq!(to_kebab_case("NewWorktree"), "new-worktree");
        assert_eq!(to_kebab_case("A"), "a");
    }

    #[test]
    fn kebab_case_edge_cases() {
        assert_eq!(to_kebab_case(""), "");
        // Digits are not word boundaries.
        assert_eq!(to_kebab_case("Level2"), "level2");
        // Consecutive capitals are split letter by letter.
        assert_eq!(to_kebab_case("HTTPServer"), "h-t-t-p-server");
    }

    #[test]
    fn space_separated() {
        assert_eq!(to_space_separated("Low"), "Low");
        assert_eq!(to_space_separated("NewWorktree"), "New Worktree");
    }

    #[test]
    fn space_separated_edge_cases() {
        assert_eq!(to_space_separated(""), "");
        assert_eq!(to_space_separated("HTTPServer"), "H T T P Server");
    }
}