1package apijson
  2
  3import (
  4	"fmt"
  5	"reflect"
  6	"slices"
  7	"sync"
  8
  9	"github.com/tidwall/gjson"
 10)
 11
 12/********************/
 13/* Validating Enums */
 14/********************/
 15
// validationEntry records the legal values registered for a single struct
// field via RegisterFieldValidator. Only the legalValues member matching the
// registered value type is populated; bools always starts at -1 ("either").
type validationEntry struct {
	// field is the struct field, located by its JSON tag name, that this entry constrains.
	field reflect.StructField
	// required is not assigned anywhere in this section.
	// NOTE(review): confirm where (and whether) it is populated and read.
	required bool
	// legalValues holds the permitted values. The decoder's validate* helpers
	// downgrade exactness to loose when a decoded value falls outside them.
	legalValues struct {
		// strings lists the permitted string values.
		strings []string
		// 1 represents true, 0 represents false, -1 represents either
		bools int
		// ints lists the permitted integer values, widened to int64.
		ints  []int64
	}
}
 26
 27type validatorFunc func(reflect.Value) exactness
 28
 29var validators sync.Map
 30var validationRegistry = map[reflect.Type][]validationEntry{}
 31
 32func RegisterFieldValidator[T any, V string | bool | int](fieldName string, values ...V) {
 33	var t T
 34	parentType := reflect.TypeOf(t)
 35
 36	if _, ok := validationRegistry[parentType]; !ok {
 37		validationRegistry[parentType] = []validationEntry{}
 38	}
 39
 40	// The following checks run at initialization time,
 41	// it is impossible for them to panic if any tests pass.
 42	if parentType.Kind() != reflect.Struct {
 43		panic(fmt.Sprintf("apijson: cannot initialize validator for non-struct %s", parentType.String()))
 44	}
 45
 46	var field reflect.StructField
 47	found := false
 48	for i := 0; i < parentType.NumField(); i++ {
 49		ptag, ok := parseJSONStructTag(parentType.Field(i))
 50		if ok && ptag.name == fieldName {
 51			field = parentType.Field(i)
 52			found = true
 53			break
 54		}
 55	}
 56
 57	if !found {
 58		panic(fmt.Sprintf("apijson: cannot find field %s in struct %s", fieldName, parentType.String()))
 59	}
 60
 61	newEntry := validationEntry{field: field}
 62	newEntry.legalValues.bools = -1 // default to either
 63
 64	switch values := any(values).(type) {
 65	case []string:
 66		newEntry.legalValues.strings = values
 67	case []int:
 68		newEntry.legalValues.ints = make([]int64, len(values))
 69		for i, value := range values {
 70			newEntry.legalValues.ints[i] = int64(value)
 71		}
 72	case []bool:
 73		for i, value := range values {
 74			var next int
 75			if value {
 76				next = 1
 77			}
 78			if i > 0 && newEntry.legalValues.bools != next {
 79				newEntry.legalValues.bools = -1 // accept either
 80				break
 81			}
 82			newEntry.legalValues.bools = next
 83		}
 84	}
 85
 86	// Store the information necessary to create a validator, so that we can use it
 87	// lazily create the validator function when did.
 88	validationRegistry[parentType] = append(validationRegistry[parentType], newEntry)
 89}
 90
 91func (state *decoderState) validateString(v reflect.Value) {
 92	if state.validator == nil {
 93		return
 94	}
 95	if !slices.Contains(state.validator.legalValues.strings, v.String()) {
 96		state.exactness = loose
 97	}
 98}
 99
100func (state *decoderState) validateInt(v reflect.Value) {
101	if state.validator == nil {
102		return
103	}
104	if !slices.Contains(state.validator.legalValues.ints, v.Int()) {
105		state.exactness = loose
106	}
107}
108
109func (state *decoderState) validateBool(v reflect.Value) {
110	if state.validator == nil {
111		return
112	}
113	b := v.Bool()
114	if state.validator.legalValues.bools == 1 && b == false {
115		state.exactness = loose
116	} else if state.validator.legalValues.bools == 0 && b == true {
117		state.exactness = loose
118	}
119}
120
121func (state *decoderState) validateOptKind(node gjson.Result, t reflect.Type) {
122	switch node.Type {
123	case gjson.JSON:
124		state.exactness = loose
125	case gjson.Null:
126		return
127	case gjson.False, gjson.True:
128		if t.Kind() != reflect.Bool {
129			state.exactness = loose
130		}
131	case gjson.Number:
132		switch t.Kind() {
133		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
134			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
135			reflect.Float32, reflect.Float64:
136			return
137		default:
138			state.exactness = loose
139		}
140	case gjson.String:
141		if t.Kind() != reflect.String {
142			state.exactness = loose
143		}
144	}
145}