decoder.go

  1package apijson
  2
  3import (
  4	"encoding/json"
  5	"errors"
  6	"fmt"
  7	"reflect"
  8	"strconv"
  9	"sync"
 10	"time"
 11	"unsafe"
 12
 13	"github.com/tidwall/gjson"
 14)
 15
// decoders caches the decoder built for each decoderEntry (type, date format
// and root flag). It is a synchronized map with roughly the following type:
// map[decoderEntry]decoderFunc
var decoders sync.Map
 19
 20// Unmarshal is similar to [encoding/json.Unmarshal] and parses the JSON-encoded
 21// data and stores it in the given pointer.
 22func Unmarshal(raw []byte, to any) error {
 23	d := &decoderBuilder{dateFormat: time.RFC3339}
 24	return d.unmarshal(raw, to)
 25}
 26
 27// UnmarshalRoot is like Unmarshal, but doesn't try to call MarshalJSON on the
 28// root element. Useful if a struct's UnmarshalJSON is overrode to use the
 29// behavior of this encoder versus the standard library.
 30func UnmarshalRoot(raw []byte, to any) error {
 31	d := &decoderBuilder{dateFormat: time.RFC3339, root: true}
 32	return d.unmarshal(raw, to)
 33}
 34
// decoderBuilder contains the 'compile-time' state of the decoder, i.e. the
// state that is consulted while decoder functions are being constructed.
type decoderBuilder struct {
	// Whether or not this is the first element and called by [UnmarshalRoot], see
	// the documentation there to see why this is necessary. It is reset to
	// false once the first non-Unmarshaler type decoder has been built.
	root bool
	// The dateFormat (a layout string handed to [time.Parse]) which is chosen
	// by the last struct tag that was seen.
	dateFormat string
}
 44
// decoderState contains the 'run-time' state of the decoder. A pointer to it
// is threaded through every decoderFunc call.
type decoderState struct {
	// strict makes guardStrict report a failure for a value that would
	// otherwise be accepted with loose exactness.
	strict bool
	// exactness records how closely the input has matched so far; decoders
	// lower it as they encounter coerced values or extra fields.
	exactness exactness
}
 50
// exactness refers to how close to the type the result was if deserialization
// was successful. This is useful in deserializing unions, where you want to try
// each entry, first with strict, then with looser validation, without actually
// having to do a lot of redundant work by unmarshalling twice (or maybe even
// more times).
//
// The values are ordered from worst to best match, and the decoders compare
// them with < and >, so the declaration order below is significant.
type exactness int8

const (
	// Some values had to be fudged a bit, for example by converting a string to
	// an int, or an enum with extra values.
	loose exactness = iota
	// There are some extra arguments, but otherwise it matches the union.
	extras
	// Exactly right.
	exact
)
 67
// decoderFunc decodes the JSON node into value, updating state (for example
// its exactness) along the way, and returns an error when the node cannot be
// decoded into value.
type decoderFunc func(node gjson.Result, value reflect.Value, state *decoderState) error
 69
// decoderField bundles everything the struct decoder needs for one struct
// field. The field order matters: the struct decoder constructs values of this
// type with unkeyed literals.
type decoderField struct {
	// tag is the parsed `json` struct tag of the field.
	tag parsedStructTag
	// fn decodes a JSON node into the field.
	fn decoderFunc
	// idx is the index path passed to reflect.Value.FieldByIndex.
	idx []int
	// goname is the Go name of the struct field.
	goname string
}
 76
// decoderEntry is the key under which a built decoder is stored in the
// decoders cache: the target type together with the builder settings
// (dateFormat and root) that were in effect when the decoder was requested.
type decoderEntry struct {
	reflect.Type
	dateFormat string
	root       bool
}
 82
 83func (d *decoderBuilder) unmarshal(raw []byte, to any) error {
 84	value := reflect.ValueOf(to).Elem()
 85	result := gjson.ParseBytes(raw)
 86	if !value.IsValid() {
 87		return fmt.Errorf("apijson: cannot marshal into invalid value")
 88	}
 89	return d.typeDecoder(value.Type())(result, value, &decoderState{strict: false, exactness: exact})
 90}
 91
 92func (d *decoderBuilder) typeDecoder(t reflect.Type) decoderFunc {
 93	entry := decoderEntry{
 94		Type:       t,
 95		dateFormat: d.dateFormat,
 96		root:       d.root,
 97	}
 98
 99	if fi, ok := decoders.Load(entry); ok {
100		return fi.(decoderFunc)
101	}
102
103	// To deal with recursive types, populate the map with an
104	// indirect func before we build it. This type waits on the
105	// real func (f) to be ready and then calls it. This indirect
106	// func is only used for recursive types.
107	var (
108		wg sync.WaitGroup
109		f  decoderFunc
110	)
111	wg.Add(1)
112	fi, loaded := decoders.LoadOrStore(entry, decoderFunc(func(node gjson.Result, v reflect.Value, state *decoderState) error {
113		wg.Wait()
114		return f(node, v, state)
115	}))
116	if loaded {
117		return fi.(decoderFunc)
118	}
119
120	// Compute the real decoder and replace the indirect func with it.
121	f = d.newTypeDecoder(t)
122	wg.Done()
123	decoders.Store(entry, f)
124	return f
125}
126
127func indirectUnmarshalerDecoder(n gjson.Result, v reflect.Value, state *decoderState) error {
128	return v.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw))
129}
130
131func unmarshalerDecoder(n gjson.Result, v reflect.Value, state *decoderState) error {
132	if v.Kind() == reflect.Pointer && v.CanSet() {
133		v.Set(reflect.New(v.Type().Elem()))
134	}
135	return v.Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw))
136}
137
138func (d *decoderBuilder) newTypeDecoder(t reflect.Type) decoderFunc {
139	if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
140		return d.newTimeTypeDecoder(t)
141	}
142	if !d.root && t.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) {
143		return unmarshalerDecoder
144	}
145	if !d.root && reflect.PointerTo(t).Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) {
146		if _, ok := unionVariants[t]; !ok {
147			return indirectUnmarshalerDecoder
148		}
149	}
150	d.root = false
151
152	if _, ok := unionRegistry[t]; ok {
153		return d.newUnionDecoder(t)
154	}
155
156	switch t.Kind() {
157	case reflect.Pointer:
158		inner := t.Elem()
159		innerDecoder := d.typeDecoder(inner)
160
161		return func(n gjson.Result, v reflect.Value, state *decoderState) error {
162			if !v.IsValid() {
163				return fmt.Errorf("apijson: unexpected invalid reflection value %+#v", v)
164			}
165
166			newValue := reflect.New(inner).Elem()
167			err := innerDecoder(n, newValue, state)
168			if err != nil {
169				return err
170			}
171
172			v.Set(newValue.Addr())
173			return nil
174		}
175	case reflect.Struct:
176		if isEmbeddedUnion(t) {
177			return d.newEmbeddedUnionDecoder(t)
178		}
179		return d.newStructTypeDecoder(t)
180	case reflect.Array:
181		fallthrough
182	case reflect.Slice:
183		return d.newArrayTypeDecoder(t)
184	case reflect.Map:
185		return d.newMapDecoder(t)
186	case reflect.Interface:
187		return func(node gjson.Result, value reflect.Value, state *decoderState) error {
188			if !value.IsValid() {
189				return fmt.Errorf("apijson: unexpected invalid value %+#v", value)
190			}
191			if node.Value() != nil && value.CanSet() {
192				value.Set(reflect.ValueOf(node.Value()))
193			}
194			return nil
195		}
196	default:
197		return d.newPrimitiveTypeDecoder(t)
198	}
199}
200
201// newUnionDecoder returns a decoderFunc that deserializes into a union using an
202// algorithm roughly similar to Pydantic's [smart algorithm].
203//
204// Conceptually this is equivalent to choosing the best schema based on how 'exact'
205// the deserialization is for each of the schemas.
206//
207// If there is a tie in the level of exactness, then the tie is broken
208// left-to-right.
209//
210// [smart algorithm]: https://docs.pydantic.dev/latest/concepts/unions/#smart-mode
211func (d *decoderBuilder) newUnionDecoder(t reflect.Type) decoderFunc {
212	unionEntry, ok := unionRegistry[t]
213	if !ok {
214		panic("apijson: couldn't find union of type " + t.String() + " in union registry")
215	}
216	decoders := []decoderFunc{}
217	for _, variant := range unionEntry.variants {
218		decoder := d.typeDecoder(variant.Type)
219		decoders = append(decoders, decoder)
220	}
221	return func(n gjson.Result, v reflect.Value, state *decoderState) error {
222		// If there is a discriminator match, circumvent the exactness logic entirely
223		for idx, variant := range unionEntry.variants {
224			decoder := decoders[idx]
225			if variant.TypeFilter != n.Type {
226				continue
227			}
228
229			if len(unionEntry.discriminatorKey) != 0 {
230				discriminatorValue := n.Get(unionEntry.discriminatorKey).Value()
231				if discriminatorValue == variant.DiscriminatorValue {
232					inner := reflect.New(variant.Type).Elem()
233					err := decoder(n, inner, state)
234					v.Set(inner)
235					return err
236				}
237			}
238		}
239
240		// Set bestExactness to worse than loose
241		bestExactness := loose - 1
242		for idx, variant := range unionEntry.variants {
243			decoder := decoders[idx]
244			if variant.TypeFilter != n.Type {
245				continue
246			}
247			sub := decoderState{strict: state.strict, exactness: exact}
248			inner := reflect.New(variant.Type).Elem()
249			err := decoder(n, inner, &sub)
250			if err != nil {
251				continue
252			}
253			if sub.exactness == exact {
254				v.Set(inner)
255				return nil
256			}
257			if sub.exactness > bestExactness {
258				v.Set(inner)
259				bestExactness = sub.exactness
260			}
261		}
262
263		if bestExactness < loose {
264			return errors.New("apijson: was not able to coerce type as union")
265		}
266
267		if guardStrict(state, bestExactness != exact) {
268			return errors.New("apijson: was not able to coerce type as union strictly")
269		}
270
271		return nil
272	}
273}
274
275func (d *decoderBuilder) newMapDecoder(t reflect.Type) decoderFunc {
276	keyType := t.Key()
277	itemType := t.Elem()
278	itemDecoder := d.typeDecoder(itemType)
279
280	return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) {
281		mapValue := reflect.MakeMapWithSize(t, len(node.Map()))
282
283		node.ForEach(func(key, value gjson.Result) bool {
284			// It's fine for us to just use `ValueOf` here because the key types will
285			// always be primitive types so we don't need to decode it using the standard pattern
286			keyValue := reflect.ValueOf(key.Value())
287			if !keyValue.IsValid() {
288				if err == nil {
289					err = fmt.Errorf("apijson: received invalid key type %v", keyValue.String())
290				}
291				return false
292			}
293			if keyValue.Type() != keyType {
294				if err == nil {
295					err = fmt.Errorf("apijson: expected key type %v but got %v", keyType, keyValue.Type())
296				}
297				return false
298			}
299
300			itemValue := reflect.New(itemType).Elem()
301			itemerr := itemDecoder(value, itemValue, state)
302			if itemerr != nil {
303				if err == nil {
304					err = itemerr
305				}
306				return false
307			}
308
309			mapValue.SetMapIndex(keyValue, itemValue)
310			return true
311		})
312
313		if err != nil {
314			return err
315		}
316		value.Set(mapValue)
317		return nil
318	}
319}
320
321func (d *decoderBuilder) newArrayTypeDecoder(t reflect.Type) decoderFunc {
322	itemDecoder := d.typeDecoder(t.Elem())
323
324	return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) {
325		if !node.IsArray() {
326			return fmt.Errorf("apijson: could not deserialize to an array")
327		}
328
329		arrayNode := node.Array()
330
331		arrayValue := reflect.MakeSlice(reflect.SliceOf(t.Elem()), len(arrayNode), len(arrayNode))
332		for i, itemNode := range arrayNode {
333			err = itemDecoder(itemNode, arrayValue.Index(i), state)
334			if err != nil {
335				return err
336			}
337		}
338
339		value.Set(arrayValue)
340		return nil
341	}
342}
343
// newStructTypeDecoder builds a decoder for the struct type t.
//
// At build time every exported field is classified as one of: anonymous
// (embedded), `extras`, `inline`, `metadata` (skipped), or a regular field
// keyed by its JSON name. Fields without a json tag are skipped.
//
// The returned decoder never reports an error. Failures of individual fields
// are instead recorded as metadata with status invalid (or, for inline fields,
// by zeroing the field). JSON members that do not map to a regular field are
// collected as extra fields and lower the state's exactness to extras.
func (d *decoderBuilder) newStructTypeDecoder(t reflect.Type) decoderFunc {
	// map of json field name to struct field decoders
	decoderFields := map[string]decoderField{}
	anonymousDecoders := []decoderField{}
	extraDecoder := (*decoderField)(nil)
	var inlineDecoders []decoderField

	for i := 0; i < t.NumField(); i++ {
		idx := []int{i}
		field := t.FieldByIndex(idx)
		if !field.IsExported() {
			continue
		}
		// If this is an embedded struct, it gets its own decoder which is later
		// run against the same JSON node as the enclosing struct.
		if field.Anonymous {
			anonymousDecoders = append(anonymousDecoders, decoderField{
				fn:  d.typeDecoder(field.Type),
				idx: idx[:],
			})
			continue
		}
		// If json tag is not present, then we skip, which is intentionally
		// different behavior from the stdlib.
		ptag, ok := parseJSONStructTag(field)
		if !ok {
			continue
		}
		// The `extras` field receives JSON members that have no explicit
		// field; its decoder is built for the element type of the field.
		// NOTE(review): an earlier version of this comment talked about
		// supporting unexported fields here, but unexported fields are already
		// skipped at the top of the loop — confirm which is intended.
		if ptag.extras {
			extraDecoder = &decoderField{ptag, d.typeDecoder(field.Type.Elem()), idx, field.Name}
			continue
		}
		if ptag.inline {
			df := decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name}
			inlineDecoders = append(inlineDecoders, df)
			continue
		}
		if ptag.metadata {
			continue
		}

		// A format tag overrides the date format only while this field's
		// decoder is being built; the previous format is restored afterwards.
		oldFormat := d.dateFormat
		dateFormat, ok := parseFormatStructTag(field)
		if ok {
			switch dateFormat {
			case "date-time":
				d.dateFormat = time.RFC3339
			case "date":
				d.dateFormat = "2006-01-02"
			}
		}
		decoderFields[ptag.name] = decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name}
		d.dateFormat = oldFormat
	}

	return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) {
		// Stash the raw JSON in the unexported JSON.raw field when the struct
		// has one.
		if field := value.FieldByName("JSON"); field.IsValid() {
			if raw := field.FieldByName("raw"); raw.IsValid() {
				setUnexportedField(raw, node.Raw)
			}
		}

		for _, decoder := range anonymousDecoders {
			// ignore errors
			decoder.fn(node, value.FieldByIndex(decoder.idx), state)
		}

		for _, inlineDecoder := range inlineDecoders {
			var meta Field
			dest := value.FieldByIndex(inlineDecoder.idx)
			isValid := false
			if dest.IsValid() && node.Type != gjson.Null {
				// Inline fields are decoded from the whole node in strict
				// mode, using a separate state so the caller's state is not
				// modified by the attempt.
				inlineState := decoderState{exactness: state.exactness, strict: true}
				err = inlineDecoder.fn(node, dest, &inlineState)
				if err == nil {
					isValid = true
				}
			}

			if node.Type == gjson.Null {
				meta = Field{
					raw:    node.Raw,
					status: null,
				}
			} else if !isValid {
				// If an inline decoder fails, unset the field and move on.
				if dest.IsValid() {
					dest.SetZero()
				}
				continue
			} else if isValid {
				meta = Field{
					raw:    node.Raw,
					status: valid,
				}
			}
			setMetadataSubField(value, inlineDecoder.idx, inlineDecoder.goname, meta)
		}

		typedExtraType := reflect.Type(nil)
		typedExtraFields := reflect.Value{}
		if extraDecoder != nil {
			typedExtraType = value.FieldByIndex(extraDecoder.idx).Type()
			typedExtraFields = reflect.MakeMap(typedExtraType)
		}
		untypedExtraFields := map[string]Field{}

		for fieldName, itemNode := range node.Map() {
			df, explicit := decoderFields[fieldName]
			var (
				dest reflect.Value
				fn   decoderFunc
				meta Field
			)
			if explicit {
				fn = df.fn
				dest = value.FieldByIndex(df.idx)
			}
			if !explicit && extraDecoder != nil {
				dest = reflect.New(typedExtraType.Elem()).Elem()
				fn = extraDecoder.fn
			}

			// dest stays invalid for members that have neither an explicit
			// field nor an extras field, so nothing is decoded for them.
			isValid := false
			if dest.IsValid() && itemNode.Type != gjson.Null {
				err = fn(itemNode, dest, state)
				if err == nil {
					isValid = true
				}
			}

			if itemNode.Type == gjson.Null {
				meta = Field{
					raw:    itemNode.Raw,
					status: null,
				}
			} else if !isValid {
				meta = Field{
					raw:    itemNode.Raw,
					status: invalid,
				}
			} else if isValid {
				meta = Field{
					raw:    itemNode.Raw,
					status: valid,
				}
			}

			if explicit {
				setMetadataSubField(value, df.idx, df.goname, meta)
			}
			if !explicit {
				untypedExtraFields[fieldName] = meta
			}
			if !explicit && extraDecoder != nil {
				typedExtraFields.SetMapIndex(reflect.ValueOf(fieldName), dest)
			}
		}

		if extraDecoder != nil && typedExtraFields.Len() > 0 {
			value.FieldByIndex(extraDecoder.idx).Set(typedExtraFields)
		}

		// Set exactness to 'extras' if there are untyped, extra fields.
		if len(untypedExtraFields) > 0 && state.exactness > extras {
			state.exactness = extras
		}

		if len(untypedExtraFields) > 0 {
			setMetadataExtraFields(value, []int{-1}, "ExtraFields", untypedExtraFields)
		}
		// Per-field errors are deliberately not propagated; they have been
		// recorded in the field metadata above.
		return nil
	}
}
520
521func (d *decoderBuilder) newPrimitiveTypeDecoder(t reflect.Type) decoderFunc {
522	switch t.Kind() {
523	case reflect.String:
524		return func(n gjson.Result, v reflect.Value, state *decoderState) error {
525			v.SetString(n.String())
526			if guardStrict(state, n.Type != gjson.String) {
527				return fmt.Errorf("apijson: failed to parse string strictly")
528			}
529			// Everything that is not an object can be loosely stringified.
530			if n.Type == gjson.JSON {
531				return fmt.Errorf("apijson: failed to parse string")
532			}
533			if guardUnknown(state, v) {
534				return fmt.Errorf("apijson: failed string enum validation")
535			}
536			return nil
537		}
538	case reflect.Bool:
539		return func(n gjson.Result, v reflect.Value, state *decoderState) error {
540			v.SetBool(n.Bool())
541			if guardStrict(state, n.Type != gjson.True && n.Type != gjson.False) {
542				return fmt.Errorf("apijson: failed to parse bool strictly")
543			}
544			// Numbers and strings that are either 'true' or 'false' can be loosely
545			// deserialized as bool.
546			if n.Type == gjson.String && (n.Raw != "true" && n.Raw != "false") || n.Type == gjson.JSON {
547				return fmt.Errorf("apijson: failed to parse bool")
548			}
549			if guardUnknown(state, v) {
550				return fmt.Errorf("apijson: failed bool enum validation")
551			}
552			return nil
553		}
554	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
555		return func(n gjson.Result, v reflect.Value, state *decoderState) error {
556			v.SetInt(n.Int())
557			if guardStrict(state, n.Type != gjson.Number || n.Num != float64(int(n.Num))) {
558				return fmt.Errorf("apijson: failed to parse int strictly")
559			}
560			// Numbers, booleans, and strings that maybe look like numbers can be
561			// loosely deserialized as numbers.
562			if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) {
563				return fmt.Errorf("apijson: failed to parse int")
564			}
565			if guardUnknown(state, v) {
566				return fmt.Errorf("apijson: failed int enum validation")
567			}
568			return nil
569		}
570	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
571		return func(n gjson.Result, v reflect.Value, state *decoderState) error {
572			v.SetUint(n.Uint())
573			if guardStrict(state, n.Type != gjson.Number || n.Num != float64(int(n.Num)) || n.Num < 0) {
574				return fmt.Errorf("apijson: failed to parse uint strictly")
575			}
576			// Numbers, booleans, and strings that maybe look like numbers can be
577			// loosely deserialized as uint.
578			if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) {
579				return fmt.Errorf("apijson: failed to parse uint")
580			}
581			if guardUnknown(state, v) {
582				return fmt.Errorf("apijson: failed uint enum validation")
583			}
584			return nil
585		}
586	case reflect.Float32, reflect.Float64:
587		return func(n gjson.Result, v reflect.Value, state *decoderState) error {
588			v.SetFloat(n.Float())
589			if guardStrict(state, n.Type != gjson.Number) {
590				return fmt.Errorf("apijson: failed to parse float strictly")
591			}
592			// Numbers, booleans, and strings that maybe look like numbers can be
593			// loosely deserialized as floats.
594			if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) {
595				return fmt.Errorf("apijson: failed to parse float")
596			}
597			if guardUnknown(state, v) {
598				return fmt.Errorf("apijson: failed float enum validation")
599			}
600			return nil
601		}
602	default:
603		return func(node gjson.Result, v reflect.Value, state *decoderState) error {
604			return fmt.Errorf("unknown type received at primitive decoder: %s", t.String())
605		}
606	}
607}
608
609func (d *decoderBuilder) newTimeTypeDecoder(t reflect.Type) decoderFunc {
610	format := d.dateFormat
611	return func(n gjson.Result, v reflect.Value, state *decoderState) error {
612		parsed, err := time.Parse(format, n.Str)
613		if err == nil {
614			v.Set(reflect.ValueOf(parsed).Convert(t))
615			return nil
616		}
617
618		if guardStrict(state, true) {
619			return err
620		}
621
622		layouts := []string{
623			"2006-01-02",
624			"2006-01-02T15:04:05Z07:00",
625			"2006-01-02T15:04:05Z0700",
626			"2006-01-02T15:04:05",
627			"2006-01-02 15:04:05Z07:00",
628			"2006-01-02 15:04:05Z0700",
629			"2006-01-02 15:04:05",
630		}
631
632		for _, layout := range layouts {
633			parsed, err := time.Parse(layout, n.Str)
634			if err == nil {
635				v.Set(reflect.ValueOf(parsed).Convert(t))
636				return nil
637			}
638		}
639
640		return fmt.Errorf("unable to leniently parse date-time string: %s", n.Str)
641	}
642}
643
// setUnexportedField assigns value to field even when field was obtained
// through an unexported struct field. It does so by re-deriving a settable
// reflect.Value of the same type from the field's address. field must be
// addressable.
func setUnexportedField(field reflect.Value, value interface{}) {
	addr := unsafe.Pointer(field.UnsafeAddr())
	writable := reflect.NewAt(field.Type(), addr).Elem()
	writable.Set(reflect.ValueOf(value))
}
647
648func guardStrict(state *decoderState, cond bool) bool {
649	if !cond {
650		return false
651	}
652
653	if state.strict {
654		return true
655	}
656
657	state.exactness = loose
658	return false
659}
660
// canParseAsNumber reports whether str is accepted by strconv.ParseFloat as a
// 64-bit floating-point number.
func canParseAsNumber(str string) bool {
	if _, err := strconv.ParseFloat(str, 64); err != nil {
		return false
	}
	return true
}
665
// stringType is the reflect.Type of the built-in string type.
var stringType = reflect.TypeOf("")
667
668func guardUnknown(state *decoderState, v reflect.Value) bool {
669	if have, ok := v.Interface().(interface{ IsKnown() bool }); guardStrict(state, ok && !have.IsKnown()) {
670		return true
671	}
672
673	constantString, ok := v.Interface().(interface{ Default() string })
674	named := v.Type() != stringType
675	if guardStrict(state, ok && named && v.Equal(reflect.ValueOf(constantString.Default()))) {
676		return true
677	}
678	return false
679}