enum.go

 1package apijson
 2
 3import (
 4	"fmt"
 5	"reflect"
 6	"sync"
 7)
 8
 9/********************/
10/* Validating Enums */
11/********************/
12
13type validationEntry struct {
14	field       reflect.StructField
15	nullable    bool
16	legalValues []reflect.Value
17}
18
19type validatorFunc func(reflect.Value) exactness
20
21var validators sync.Map
22var validationRegistry = map[reflect.Type][]validationEntry{}
23
24func RegisterFieldValidator[T any, V string | bool | int](fieldName string, nullable bool, values ...V) {
25	var t T
26	parentType := reflect.TypeOf(t)
27
28	if _, ok := validationRegistry[parentType]; !ok {
29		validationRegistry[parentType] = []validationEntry{}
30	}
31
32	// The following checks run at initialization time,
33	// it is impossible for them to panic if any tests pass.
34	if parentType.Kind() != reflect.Struct {
35		panic(fmt.Sprintf("apijson: cannot initialize validator for non-struct %s", parentType.String()))
36	}
37	field, found := parentType.FieldByName(fieldName)
38	if !found {
39		panic(fmt.Sprintf("apijson: cannot initialize validator for unknown field %q in %s", fieldName, parentType.String()))
40	}
41
42	newEntry := validationEntry{field, nullable, make([]reflect.Value, len(values))}
43	for i, value := range values {
44		newEntry.legalValues[i] = reflect.ValueOf(value)
45	}
46
47	// Store the information necessary to create a validator, so that we can use it
48	// lazily create the validator function when did.
49	validationRegistry[parentType] = append(validationRegistry[parentType], newEntry)
50}
51
52// Enums are the only types which are validated
53func typeValidator(t reflect.Type) validatorFunc {
54	entry, ok := validationRegistry[t]
55	if !ok {
56		return nil
57	}
58
59	if fi, ok := validators.Load(t); ok {
60		return fi.(validatorFunc)
61	}
62
63	fi, _ := validators.LoadOrStore(t, validatorFunc(func(v reflect.Value) exactness {
64		return validateEnum(v, entry)
65	}))
66	return fi.(validatorFunc)
67}
68
69func validateEnum(v reflect.Value, entry []validationEntry) exactness {
70	if v.Kind() != reflect.Struct {
71		return loose
72	}
73
74	for _, check := range entry {
75		field := v.FieldByIndex(check.field.Index)
76		if !field.IsValid() {
77			return loose
78		}
79		for _, opt := range check.legalValues {
80			if field.Equal(opt) {
81				return exact
82			}
83		}
84	}
85
86	return loose
87}