1package apijson
2
3import (
4 "encoding/json"
5 "errors"
6 "fmt"
7 "reflect"
8 "strconv"
9 "sync"
10 "time"
11 "unsafe"
12
13 "github.com/tidwall/gjson"
14)
15
// decoders caches the decoder built for each decoderEntry so that it is only
// constructed once. It is a synchronized map with roughly the following type:
// map[decoderEntry]decoderFunc
var decoders sync.Map
19
20// Unmarshal is similar to [encoding/json.Unmarshal] and parses the JSON-encoded
21// data and stores it in the given pointer.
22func Unmarshal(raw []byte, to any) error {
23 d := &decoderBuilder{dateFormat: time.RFC3339}
24 return d.unmarshal(raw, to)
25}
26
27// UnmarshalRoot is like Unmarshal, but doesn't try to call MarshalJSON on the
28// root element. Useful if a struct's UnmarshalJSON is overrode to use the
29// behavior of this encoder versus the standard library.
30func UnmarshalRoot(raw []byte, to any) error {
31 d := &decoderBuilder{dateFormat: time.RFC3339, root: true}
32 return d.unmarshal(raw, to)
33}
34
// decoderBuilder contains the 'compile-time' state of the decoder, i.e. the
// state that is consulted while decoderFuncs are being constructed.
type decoderBuilder struct {
	// Whether or not this is the first element and called by [UnmarshalRoot], see
	// the documentation there to see why this is necessary. It is reset to
	// false once the root type's decoder has been chosen.
	root bool
	// The dateFormat (a layout string as used by [time.Parse]) which is chosen
	// by the last struct tag that was seen.
	dateFormat string
}
44
// decoderState contains the 'run-time' state of the decoder.
type decoderState struct {
	// strict makes lossy conversions fail with an error instead of merely
	// lowering exactness (see guardStrict).
	strict bool
	// exactness records how closely the input has matched the target type so
	// far; decoders lower it as they encounter coercions or extra fields.
	exactness exactness
}
50
// Exactness refers to how close to the type the result was if deserialization
// was successful. This is useful in deserializing unions, where you want to try
// each entry, first with strict, then with looser validation, without actually
// having to do a lot of redundant work by unmarshalling twice (or maybe even
// more times).
type exactness int8

const (
	// Some values had to be fudged a bit, for example by converting a string to
	// an int, or an enum with extra values.
	loose exactness = iota
	// There are some extra arguments, but otherwise it matches the union.
	extras
	// Exactly right.
	exact
)
67
// decoderFunc decodes the JSON node into value, reading and updating state as
// it goes. It returns an error when node cannot be represented in value.
type decoderFunc func(node gjson.Result, value reflect.Value, state *decoderState) error
69
// decoderField bundles what the struct decoder needs to know about a single
// struct field.
type decoderField struct {
	// tag is the parsed json struct tag of the field.
	tag parsedStructTag
	// fn decodes a JSON node into the field.
	fn decoderFunc
	// idx is the field's index path, as accepted by reflect.Value.FieldByIndex.
	idx []int
	// goname is the Go name of the field.
	goname string
}
76
// decoderEntry is the key under which built decoders are cached in decoders.
// Besides the type itself it includes the builder settings that were in effect
// when the decoder was requested.
type decoderEntry struct {
	reflect.Type
	dateFormat string
	root       bool
}
82
83func (d *decoderBuilder) unmarshal(raw []byte, to any) error {
84 value := reflect.ValueOf(to).Elem()
85 result := gjson.ParseBytes(raw)
86 if !value.IsValid() {
87 return fmt.Errorf("apijson: cannot marshal into invalid value")
88 }
89 return d.typeDecoder(value.Type())(result, value, &decoderState{strict: false, exactness: exact})
90}
91
92func (d *decoderBuilder) typeDecoder(t reflect.Type) decoderFunc {
93 entry := decoderEntry{
94 Type: t,
95 dateFormat: d.dateFormat,
96 root: d.root,
97 }
98
99 if fi, ok := decoders.Load(entry); ok {
100 return fi.(decoderFunc)
101 }
102
103 // To deal with recursive types, populate the map with an
104 // indirect func before we build it. This type waits on the
105 // real func (f) to be ready and then calls it. This indirect
106 // func is only used for recursive types.
107 var (
108 wg sync.WaitGroup
109 f decoderFunc
110 )
111 wg.Add(1)
112 fi, loaded := decoders.LoadOrStore(entry, decoderFunc(func(node gjson.Result, v reflect.Value, state *decoderState) error {
113 wg.Wait()
114 return f(node, v, state)
115 }))
116 if loaded {
117 return fi.(decoderFunc)
118 }
119
120 // Compute the real decoder and replace the indirect func with it.
121 f = d.newTypeDecoder(t)
122 wg.Done()
123 decoders.Store(entry, f)
124 return f
125}
126
127func indirectUnmarshalerDecoder(n gjson.Result, v reflect.Value, state *decoderState) error {
128 return v.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw))
129}
130
131func unmarshalerDecoder(n gjson.Result, v reflect.Value, state *decoderState) error {
132 if v.Kind() == reflect.Pointer && v.CanSet() {
133 v.Set(reflect.New(v.Type().Elem()))
134 }
135 return v.Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw))
136}
137
138func (d *decoderBuilder) newTypeDecoder(t reflect.Type) decoderFunc {
139 if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
140 return d.newTimeTypeDecoder(t)
141 }
142 if !d.root && t.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) {
143 return unmarshalerDecoder
144 }
145 if !d.root && reflect.PointerTo(t).Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) {
146 if _, ok := unionVariants[t]; !ok {
147 return indirectUnmarshalerDecoder
148 }
149 }
150 d.root = false
151
152 if _, ok := unionRegistry[t]; ok {
153 return d.newUnionDecoder(t)
154 }
155
156 switch t.Kind() {
157 case reflect.Pointer:
158 inner := t.Elem()
159 innerDecoder := d.typeDecoder(inner)
160
161 return func(n gjson.Result, v reflect.Value, state *decoderState) error {
162 if !v.IsValid() {
163 return fmt.Errorf("apijson: unexpected invalid reflection value %+#v", v)
164 }
165
166 newValue := reflect.New(inner).Elem()
167 err := innerDecoder(n, newValue, state)
168 if err != nil {
169 return err
170 }
171
172 v.Set(newValue.Addr())
173 return nil
174 }
175 case reflect.Struct:
176 if isEmbeddedUnion(t) {
177 return d.newEmbeddedUnionDecoder(t)
178 }
179 return d.newStructTypeDecoder(t)
180 case reflect.Array:
181 fallthrough
182 case reflect.Slice:
183 return d.newArrayTypeDecoder(t)
184 case reflect.Map:
185 return d.newMapDecoder(t)
186 case reflect.Interface:
187 return func(node gjson.Result, value reflect.Value, state *decoderState) error {
188 if !value.IsValid() {
189 return fmt.Errorf("apijson: unexpected invalid value %+#v", value)
190 }
191 if node.Value() != nil && value.CanSet() {
192 value.Set(reflect.ValueOf(node.Value()))
193 }
194 return nil
195 }
196 default:
197 return d.newPrimitiveTypeDecoder(t)
198 }
199}
200
201// newUnionDecoder returns a decoderFunc that deserializes into a union using an
202// algorithm roughly similar to Pydantic's [smart algorithm].
203//
204// Conceptually this is equivalent to choosing the best schema based on how 'exact'
205// the deserialization is for each of the schemas.
206//
207// If there is a tie in the level of exactness, then the tie is broken
208// left-to-right.
209//
210// [smart algorithm]: https://docs.pydantic.dev/latest/concepts/unions/#smart-mode
211func (d *decoderBuilder) newUnionDecoder(t reflect.Type) decoderFunc {
212 unionEntry, ok := unionRegistry[t]
213 if !ok {
214 panic("apijson: couldn't find union of type " + t.String() + " in union registry")
215 }
216 decoders := []decoderFunc{}
217 for _, variant := range unionEntry.variants {
218 decoder := d.typeDecoder(variant.Type)
219 decoders = append(decoders, decoder)
220 }
221 return func(n gjson.Result, v reflect.Value, state *decoderState) error {
222 // If there is a discriminator match, circumvent the exactness logic entirely
223 for idx, variant := range unionEntry.variants {
224 decoder := decoders[idx]
225 if variant.TypeFilter != n.Type {
226 continue
227 }
228
229 if len(unionEntry.discriminatorKey) != 0 {
230 discriminatorValue := n.Get(unionEntry.discriminatorKey).Value()
231 if discriminatorValue == variant.DiscriminatorValue {
232 inner := reflect.New(variant.Type).Elem()
233 err := decoder(n, inner, state)
234 v.Set(inner)
235 return err
236 }
237 }
238 }
239
240 // Set bestExactness to worse than loose
241 bestExactness := loose - 1
242 for idx, variant := range unionEntry.variants {
243 decoder := decoders[idx]
244 if variant.TypeFilter != n.Type {
245 continue
246 }
247 sub := decoderState{strict: state.strict, exactness: exact}
248 inner := reflect.New(variant.Type).Elem()
249 err := decoder(n, inner, &sub)
250 if err != nil {
251 continue
252 }
253 if sub.exactness == exact {
254 v.Set(inner)
255 return nil
256 }
257 if sub.exactness > bestExactness {
258 v.Set(inner)
259 bestExactness = sub.exactness
260 }
261 }
262
263 if bestExactness < loose {
264 return errors.New("apijson: was not able to coerce type as union")
265 }
266
267 if guardStrict(state, bestExactness != exact) {
268 return errors.New("apijson: was not able to coerce type as union strictly")
269 }
270
271 return nil
272 }
273}
274
275func (d *decoderBuilder) newMapDecoder(t reflect.Type) decoderFunc {
276 keyType := t.Key()
277 itemType := t.Elem()
278 itemDecoder := d.typeDecoder(itemType)
279
280 return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) {
281 mapValue := reflect.MakeMapWithSize(t, len(node.Map()))
282
283 node.ForEach(func(key, value gjson.Result) bool {
284 // It's fine for us to just use `ValueOf` here because the key types will
285 // always be primitive types so we don't need to decode it using the standard pattern
286 keyValue := reflect.ValueOf(key.Value())
287 if !keyValue.IsValid() {
288 if err == nil {
289 err = fmt.Errorf("apijson: received invalid key type %v", keyValue.String())
290 }
291 return false
292 }
293 if keyValue.Type() != keyType {
294 if err == nil {
295 err = fmt.Errorf("apijson: expected key type %v but got %v", keyType, keyValue.Type())
296 }
297 return false
298 }
299
300 itemValue := reflect.New(itemType).Elem()
301 itemerr := itemDecoder(value, itemValue, state)
302 if itemerr != nil {
303 if err == nil {
304 err = itemerr
305 }
306 return false
307 }
308
309 mapValue.SetMapIndex(keyValue, itemValue)
310 return true
311 })
312
313 if err != nil {
314 return err
315 }
316 value.Set(mapValue)
317 return nil
318 }
319}
320
321func (d *decoderBuilder) newArrayTypeDecoder(t reflect.Type) decoderFunc {
322 itemDecoder := d.typeDecoder(t.Elem())
323
324 return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) {
325 if !node.IsArray() {
326 return fmt.Errorf("apijson: could not deserialize to an array")
327 }
328
329 arrayNode := node.Array()
330
331 arrayValue := reflect.MakeSlice(reflect.SliceOf(t.Elem()), len(arrayNode), len(arrayNode))
332 for i, itemNode := range arrayNode {
333 err = itemDecoder(itemNode, arrayValue.Index(i), state)
334 if err != nil {
335 return err
336 }
337 }
338
339 value.Set(arrayValue)
340 return nil
341 }
342}
343
// newStructTypeDecoder returns the decoder for struct type t.
//
// At build time the exported fields of t are sorted into four groups:
// embedded (anonymous) fields, the field whose parsed tag is marked extras,
// fields whose tag is marked inline, and regular fields keyed by their JSON
// name. Fields without a json tag, and fields whose tag is marked metadata,
// get no decoder.
//
// The returned decoderFunc always returns nil: failures of individual fields
// are reported through the per-field metadata instead of being propagated.
func (d *decoderBuilder) newStructTypeDecoder(t reflect.Type) decoderFunc {
	// map of json field name to struct field decoders
	decoderFields := map[string]decoderField{}
	anonymousDecoders := []decoderField{}
	extraDecoder := (*decoderField)(nil)
	var inlineDecoders []decoderField

	for i := 0; i < t.NumField(); i++ {
		idx := []int{i}
		field := t.FieldByIndex(idx)
		if !field.IsExported() {
			continue
		}
		// If this is an embedded struct, traverse one level deeper to extract
		// the fields and get their decoders as well.
		if field.Anonymous {
			anonymousDecoders = append(anonymousDecoders, decoderField{
				fn:  d.typeDecoder(field.Type),
				idx: idx[:],
			})
			continue
		}
		// If json tag is not present, then we skip, which is intentionally
		// different behavior from the stdlib.
		ptag, ok := parseJSONStructTag(field)
		if !ok {
			continue
		}
		// We only want to support unexported fields if they're tagged with
		// `extras` because that field shouldn't be part of the public API.
		// NOTE(review): unexported fields are already skipped at the top of
		// the loop, so only exported extras fields can reach this point.
		// The extras decoder is built for the field type's element type, so
		// the field is presumably expected to be a map.
		if ptag.extras {
			extraDecoder = &decoderField{ptag, d.typeDecoder(field.Type.Elem()), idx, field.Name}
			continue
		}
		if ptag.inline {
			df := decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name}
			inlineDecoders = append(inlineDecoders, df)
			continue
		}
		if ptag.metadata {
			continue
		}

		// A format tag on the field overrides the builder's date format, but
		// only while this field's decoder is being built.
		oldFormat := d.dateFormat
		dateFormat, ok := parseFormatStructTag(field)
		if ok {
			switch dateFormat {
			case "date-time":
				d.dateFormat = time.RFC3339
			case "date":
				d.dateFormat = "2006-01-02"
			}
		}
		decoderFields[ptag.name] = decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name}
		d.dateFormat = oldFormat
	}

	return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) {
		// If the struct has a JSON field with an unexported raw field, keep
		// the raw JSON text of the whole object there.
		if field := value.FieldByName("JSON"); field.IsValid() {
			if raw := field.FieldByName("raw"); raw.IsValid() {
				setUnexportedField(raw, node.Raw)
			}
		}

		// Embedded fields are decoded from the same node as the struct itself.
		for _, decoder := range anonymousDecoders {
			// ignore errors
			decoder.fn(node, value.FieldByIndex(decoder.idx), state)
		}

		// Inline fields are also decoded from the whole node, but strictly and
		// with their own state: exactness changes made by an inline decoder
		// are not copied back into state.
		for _, inlineDecoder := range inlineDecoders {
			var meta Field
			dest := value.FieldByIndex(inlineDecoder.idx)
			isValid := false
			if dest.IsValid() && node.Type != gjson.Null {
				inlineState := decoderState{exactness: state.exactness, strict: true}
				err = inlineDecoder.fn(node, dest, &inlineState)
				if err == nil {
					isValid = true
				}
			}

			if node.Type == gjson.Null {
				meta = Field{
					raw:    node.Raw,
					status: null,
				}
			} else if !isValid {
				// If an inline decoder fails, unset the field and move on.
				if dest.IsValid() {
					dest.SetZero()
				}
				continue
			} else if isValid {
				meta = Field{
					raw:    node.Raw,
					status: valid,
				}
			}
			setMetadataSubField(value, inlineDecoder.idx, inlineDecoder.goname, meta)
		}

		// Keys without a regular field decoder are collected twice: decoded
		// into a map of the extras field's type (when such a field exists),
		// and as raw metadata in untypedExtraFields.
		typedExtraType := reflect.Type(nil)
		typedExtraFields := reflect.Value{}
		if extraDecoder != nil {
			typedExtraType = value.FieldByIndex(extraDecoder.idx).Type()
			typedExtraFields = reflect.MakeMap(typedExtraType)
		}
		untypedExtraFields := map[string]Field{}

		for fieldName, itemNode := range node.Map() {
			df, explicit := decoderFields[fieldName]
			var (
				dest reflect.Value
				fn   decoderFunc
				meta Field
			)
			if explicit {
				fn = df.fn
				dest = value.FieldByIndex(df.idx)
			}
			if !explicit && extraDecoder != nil {
				dest = reflect.New(typedExtraType.Elem()).Elem()
				fn = extraDecoder.fn
			}

			// dest stays invalid for unknown keys when there is no extras
			// field; such keys are never decoded and end up with status
			// invalid below (unless they are null).
			isValid := false
			if dest.IsValid() && itemNode.Type != gjson.Null {
				err = fn(itemNode, dest, state)
				if err == nil {
					isValid = true
				}
			}

			if itemNode.Type == gjson.Null {
				meta = Field{
					raw:    itemNode.Raw,
					status: null,
				}
			} else if !isValid {
				meta = Field{
					raw:    itemNode.Raw,
					status: invalid,
				}
			} else if isValid {
				meta = Field{
					raw:    itemNode.Raw,
					status: valid,
				}
			}

			if explicit {
				setMetadataSubField(value, df.idx, df.goname, meta)
			}
			if !explicit {
				untypedExtraFields[fieldName] = meta
			}
			// NOTE(review): the typed entry is stored even when the value was
			// null or failed to decode — confirm this is intended.
			if !explicit && extraDecoder != nil {
				typedExtraFields.SetMapIndex(reflect.ValueOf(fieldName), dest)
			}
		}

		if extraDecoder != nil && typedExtraFields.Len() > 0 {
			value.FieldByIndex(extraDecoder.idx).Set(typedExtraFields)
		}

		// Set exactness to 'extras' if there are untyped, extra fields.
		if len(untypedExtraFields) > 0 && state.exactness > extras {
			state.exactness = extras
		}

		if len(untypedExtraFields) > 0 {
			setMetadataExtraFields(value, []int{-1}, "ExtraFields", untypedExtraFields)
		}
		return nil
	}
}
520
521func (d *decoderBuilder) newPrimitiveTypeDecoder(t reflect.Type) decoderFunc {
522 switch t.Kind() {
523 case reflect.String:
524 return func(n gjson.Result, v reflect.Value, state *decoderState) error {
525 v.SetString(n.String())
526 if guardStrict(state, n.Type != gjson.String) {
527 return fmt.Errorf("apijson: failed to parse string strictly")
528 }
529 // Everything that is not an object can be loosely stringified.
530 if n.Type == gjson.JSON {
531 return fmt.Errorf("apijson: failed to parse string")
532 }
533 if guardUnknown(state, v) {
534 return fmt.Errorf("apijson: failed string enum validation")
535 }
536 return nil
537 }
538 case reflect.Bool:
539 return func(n gjson.Result, v reflect.Value, state *decoderState) error {
540 v.SetBool(n.Bool())
541 if guardStrict(state, n.Type != gjson.True && n.Type != gjson.False) {
542 return fmt.Errorf("apijson: failed to parse bool strictly")
543 }
544 // Numbers and strings that are either 'true' or 'false' can be loosely
545 // deserialized as bool.
546 if n.Type == gjson.String && (n.Raw != "true" && n.Raw != "false") || n.Type == gjson.JSON {
547 return fmt.Errorf("apijson: failed to parse bool")
548 }
549 if guardUnknown(state, v) {
550 return fmt.Errorf("apijson: failed bool enum validation")
551 }
552 return nil
553 }
554 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
555 return func(n gjson.Result, v reflect.Value, state *decoderState) error {
556 v.SetInt(n.Int())
557 if guardStrict(state, n.Type != gjson.Number || n.Num != float64(int(n.Num))) {
558 return fmt.Errorf("apijson: failed to parse int strictly")
559 }
560 // Numbers, booleans, and strings that maybe look like numbers can be
561 // loosely deserialized as numbers.
562 if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) {
563 return fmt.Errorf("apijson: failed to parse int")
564 }
565 if guardUnknown(state, v) {
566 return fmt.Errorf("apijson: failed int enum validation")
567 }
568 return nil
569 }
570 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
571 return func(n gjson.Result, v reflect.Value, state *decoderState) error {
572 v.SetUint(n.Uint())
573 if guardStrict(state, n.Type != gjson.Number || n.Num != float64(int(n.Num)) || n.Num < 0) {
574 return fmt.Errorf("apijson: failed to parse uint strictly")
575 }
576 // Numbers, booleans, and strings that maybe look like numbers can be
577 // loosely deserialized as uint.
578 if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) {
579 return fmt.Errorf("apijson: failed to parse uint")
580 }
581 if guardUnknown(state, v) {
582 return fmt.Errorf("apijson: failed uint enum validation")
583 }
584 return nil
585 }
586 case reflect.Float32, reflect.Float64:
587 return func(n gjson.Result, v reflect.Value, state *decoderState) error {
588 v.SetFloat(n.Float())
589 if guardStrict(state, n.Type != gjson.Number) {
590 return fmt.Errorf("apijson: failed to parse float strictly")
591 }
592 // Numbers, booleans, and strings that maybe look like numbers can be
593 // loosely deserialized as floats.
594 if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) {
595 return fmt.Errorf("apijson: failed to parse float")
596 }
597 if guardUnknown(state, v) {
598 return fmt.Errorf("apijson: failed float enum validation")
599 }
600 return nil
601 }
602 default:
603 return func(node gjson.Result, v reflect.Value, state *decoderState) error {
604 return fmt.Errorf("unknown type received at primitive decoder: %s", t.String())
605 }
606 }
607}
608
609func (d *decoderBuilder) newTimeTypeDecoder(t reflect.Type) decoderFunc {
610 format := d.dateFormat
611 return func(n gjson.Result, v reflect.Value, state *decoderState) error {
612 parsed, err := time.Parse(format, n.Str)
613 if err == nil {
614 v.Set(reflect.ValueOf(parsed).Convert(t))
615 return nil
616 }
617
618 if guardStrict(state, true) {
619 return err
620 }
621
622 layouts := []string{
623 "2006-01-02",
624 "2006-01-02T15:04:05Z07:00",
625 "2006-01-02T15:04:05Z0700",
626 "2006-01-02T15:04:05",
627 "2006-01-02 15:04:05Z07:00",
628 "2006-01-02 15:04:05Z0700",
629 "2006-01-02 15:04:05",
630 }
631
632 for _, layout := range layouts {
633 parsed, err := time.Parse(layout, n.Str)
634 if err == nil {
635 v.Set(reflect.ValueOf(parsed).Convert(t))
636 return nil
637 }
638 }
639
640 return fmt.Errorf("unable to leniently parse date-time string: %s", n.Str)
641 }
642}
643
// setUnexportedField assigns value to field even when field was obtained
// through an unexported struct field. It does so by building a second,
// writable reflect.Value at the field's address. field must be addressable
// and value must be assignable to the field's type.
func setUnexportedField(field reflect.Value, value interface{}) {
	writable := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem()
	writable.Set(reflect.ValueOf(value))
}
647
648func guardStrict(state *decoderState, cond bool) bool {
649 if !cond {
650 return false
651 }
652
653 if state.strict {
654 return true
655 }
656
657 state.exactness = loose
658 return false
659}
660
// canParseAsNumber reports whether str is accepted by strconv.ParseFloat as a
// 64-bit floating-point number.
func canParseAsNumber(str string) bool {
	if _, err := strconv.ParseFloat(str, 64); err != nil {
		return false
	}
	return true
}
665
// stringType is the reflect.Type of the built-in string type. guardUnknown
// compares against it to tell named string types from plain strings.
var stringType = reflect.TypeOf(string(""))
667
668func guardUnknown(state *decoderState, v reflect.Value) bool {
669 if have, ok := v.Interface().(interface{ IsKnown() bool }); guardStrict(state, ok && !have.IsKnown()) {
670 return true
671 }
672
673 constantString, ok := v.Interface().(interface{ Default() string })
674 named := v.Type() != stringType
675 if guardStrict(state, ok && named && v.Equal(reflect.ValueOf(constantString.Default()))) {
676 return true
677 }
678 return false
679}