1// The deserialization algorithm from apijson may be subject to improvements
2// between minor versions, particularly with respect to calling [json.Unmarshal]
3// into param unions.
4
5package apijson
6
7import (
8 "encoding/json"
9 "fmt"
10 "github.com/openai/openai-go/packages/param"
11 "reflect"
12 "strconv"
13 "sync"
14 "time"
15 "unsafe"
16
17 "github.com/tidwall/gjson"
18)
19
// decoders caches compiled decoders. It is a synchronized map with roughly the
// type map[decoderEntry]decoderFunc: entries are keyed by the Go type together
// with the date format and root flag in effect when the decoder was built,
// because those settings change which decoder gets built.
var decoders sync.Map
23
24// Unmarshal is similar to [encoding/json.Unmarshal] and parses the JSON-encoded
25// data and stores it in the given pointer.
26func Unmarshal(raw []byte, to any) error {
27 d := &decoderBuilder{dateFormat: time.RFC3339}
28 return d.unmarshal(raw, to)
29}
30
31// UnmarshalRoot is like Unmarshal, but doesn't try to call MarshalJSON on the
32// root element. Useful if a struct's UnmarshalJSON is overrode to use the
33// behavior of this encoder versus the standard library.
34func UnmarshalRoot(raw []byte, to any) error {
35 d := &decoderBuilder{dateFormat: time.RFC3339, root: true}
36 return d.unmarshal(raw, to)
37}
38
// decoderBuilder contains the 'compile-time' state of the decoder, i.e. the
// settings consulted while decoder funcs are being built (not while they run).
type decoderBuilder struct {
	// Whether or not this is the first element and called by [UnmarshalRoot], see
	// the documentation there to see why this is necessary. newTypeDecoder
	// clears it once the root type's json.Unmarshaler checks have been skipped.
	root bool
	// The dateFormat (a layout string for [time.Parse]) which is chosen by the
	// last `format` struct tag that was seen. newStructTypeDecoder sets it
	// temporarily while building each field's decoder.
	dateFormat string
}
48
// decoderState contains the 'run-time' state of the decoder, threaded by
// pointer through every decoderFunc during a single decode.
type decoderState struct {
	// strict makes imperfect matches fail (guardStrict returns true) instead
	// of merely lowering exactness.
	strict bool
	// exactness records how closely the input has matched so far; decoders
	// lower it as they encounter imperfect matches.
	exactness exactness
	// validator is the validation entry for the field currently being
	// decoded. validatedTypeDecoder sets it for the duration of one field and
	// then clears it back to nil.
	validator *validationEntry
}
55
// Exactness refers to how close to the type the result was if deserialization
// was successful. This is useful in deserializing unions, where you want to try
// each entry, first with strict, then with looser validation, without actually
// having to do a lot of redundant work by unmarshalling twice (or maybe even
// more times).
//
// Values are ordered: a larger exactness is a better match, so decoders can
// downgrade with comparisons such as `state.exactness > extras`.
type exactness int8

const (
	// Some values had to be fudged a bit, for example by converting a string to
	// an int, or an enum with extra values.
	loose exactness = iota
	// There are some extra arguments, but otherwise it matches the union.
	extras
	// Exactly right.
	exact
)
72
// decoderFunc decodes node into value, reading and updating state (strictness,
// exactness and the active validator) as it goes.
type decoderFunc func(node gjson.Result, value reflect.Value, state *decoderState) error
74
// decoderField describes how newStructTypeDecoder decodes one struct field.
// Field order matters: values are built with positional composite literals.
type decoderField struct {
	// tag is the field's parsed `json` struct tag.
	tag parsedStructTag
	// fn decodes a JSON node into the field.
	fn decoderFunc
	// idx is the field's index path, for reflect.Value.FieldByIndex.
	idx []int
	// goname is the Go field name, passed along when recording field metadata.
	goname string
}
81
// decoderEntry is the cache key for the decoders map. dateFormat and root are
// part of the key because they change the decoder that gets built: time
// decoders capture the date format, and root controls whether json.Unmarshaler
// implementations are consulted.
type decoderEntry struct {
	reflect.Type
	dateFormat string
	root       bool
}
87
88func (d *decoderBuilder) unmarshal(raw []byte, to any) error {
89 value := reflect.ValueOf(to).Elem()
90 result := gjson.ParseBytes(raw)
91 if !value.IsValid() {
92 return fmt.Errorf("apijson: cannot marshal into invalid value")
93 }
94 return d.typeDecoder(value.Type())(result, value, &decoderState{strict: false, exactness: exact})
95}
96
97// unmarshalWithExactness is used for internal testing purposes.
98func (d *decoderBuilder) unmarshalWithExactness(raw []byte, to any) (exactness, error) {
99 value := reflect.ValueOf(to).Elem()
100 result := gjson.ParseBytes(raw)
101 if !value.IsValid() {
102 return 0, fmt.Errorf("apijson: cannot marshal into invalid value")
103 }
104 state := decoderState{strict: false, exactness: exact}
105 err := d.typeDecoder(value.Type())(result, value, &state)
106 return state.exactness, err
107}
108
// typeDecoder returns the decoder for t under the builder's current dateFormat
// and root settings. It builds the decoder on first use and caches it in the
// package-level decoders map. The cache key is (type, dateFormat, root), so one
// Go type may have several cached decoders.
func (d *decoderBuilder) typeDecoder(t reflect.Type) decoderFunc {
	entry := decoderEntry{
		Type:       t,
		dateFormat: d.dateFormat,
		root:       d.root,
	}

	// Fast path: a decoder (or the placeholder below) is already cached.
	if fi, ok := decoders.Load(entry); ok {
		return fi.(decoderFunc)
	}

	// To deal with recursive types, populate the map with an
	// indirect func before we build it. This type waits on the
	// real func (f) to be ready and then calls it. This indirect
	// func is only used for recursive types.
	// NOTE(review): if newTypeDecoder panics, wg.Done is never called and
	// any caller that obtained the placeholder blocks forever in wg.Wait.
	var (
		wg sync.WaitGroup
		f  decoderFunc
	)
	wg.Add(1)
	fi, loaded := decoders.LoadOrStore(entry, decoderFunc(func(node gjson.Result, v reflect.Value, state *decoderState) error {
		wg.Wait()
		return f(node, v, state)
	}))
	// Another goroutine won the race; use whatever it stored.
	if loaded {
		return fi.(decoderFunc)
	}

	// Compute the real decoder and replace the indirect func with it.
	f = d.newTypeDecoder(t)
	wg.Done()
	decoders.Store(entry, f)
	return f
}
143
144// validatedTypeDecoder wraps the type decoder with a validator. This is helpful
145// for ensuring that enum fields are correct.
146func (d *decoderBuilder) validatedTypeDecoder(t reflect.Type, entry *validationEntry) decoderFunc {
147 dec := d.typeDecoder(t)
148 if entry == nil {
149 return dec
150 }
151
152 // Thread the current validation entry through the decoder,
153 // but clean up in time for the next field.
154 return func(node gjson.Result, v reflect.Value, state *decoderState) error {
155 state.validator = entry
156 err := dec(node, v, state)
157 state.validator = nil
158 return err
159 }
160}
161
162func indirectUnmarshalerDecoder(n gjson.Result, v reflect.Value, state *decoderState) error {
163 return v.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw))
164}
165
166func unmarshalerDecoder(n gjson.Result, v reflect.Value, state *decoderState) error {
167 if v.Kind() == reflect.Pointer && v.CanSet() {
168 v.Set(reflect.New(v.Type().Elem()))
169 }
170 return v.Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw))
171}
172
// newTypeDecoder builds the decoder for t. Special cases are checked in order
// before falling back to a decoder chosen by t's kind:
//
//  1. types convertible to time.Time;
//  2. types implementing param.Optional;
//  3. when not building the root of UnmarshalRoot: types implementing
//     json.Unmarshaler, or whose pointer does (unless t is in unionVariants);
//  4. types registered in unionRegistry.
//
// d.root is cleared once the json.Unmarshaler checks are passed, so decoders
// built afterwards for nested types do consult json.Unmarshaler.
func (d *decoderBuilder) newTypeDecoder(t reflect.Type) decoderFunc {
	if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
		return d.newTimeTypeDecoder(t)
	}

	if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) {
		return d.newOptTypeDecoder(t)
	}

	if !d.root && t.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) {
		return unmarshalerDecoder
	}
	if !d.root && reflect.PointerTo(t).Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) {
		// Union variants are decoded by this package even if they define
		// UnmarshalJSON on a pointer receiver.
		if _, ok := unionVariants[t]; !ok {
			return indirectUnmarshalerDecoder
		}
	}
	d.root = false

	if _, ok := unionRegistry[t]; ok {
		if isStructUnion(t) {
			return d.newStructUnionDecoder(t)
		}
		return d.newUnionDecoder(t)
	}

	switch t.Kind() {
	case reflect.Pointer:
		inner := t.Elem()
		innerDecoder := d.typeDecoder(inner)

		// Decode into a fresh value and only point v at it on success, so a
		// failed decode leaves v unchanged.
		return func(n gjson.Result, v reflect.Value, state *decoderState) error {
			if !v.IsValid() {
				return fmt.Errorf("apijson: unexpected invalid reflection value %+#v", v)
			}

			newValue := reflect.New(inner).Elem()
			err := innerDecoder(n, newValue, state)
			if err != nil {
				return err
			}

			v.Set(newValue.Addr())
			return nil
		}
	case reflect.Struct:
		if isStructUnion(t) {
			return d.newStructUnionDecoder(t)
		}
		return d.newStructTypeDecoder(t)
	case reflect.Array:
		fallthrough
	case reflect.Slice:
		return d.newArrayTypeDecoder(t)
	case reflect.Map:
		return d.newMapDecoder(t)
	case reflect.Interface:
		// Interface targets receive gjson's generic Go value for the node.
		// JSON null (Value() == nil) leaves the target untouched.
		// NOTE(review): for a non-empty interface type, Set panics if the
		// generic value does not implement it.
		return func(node gjson.Result, value reflect.Value, state *decoderState) error {
			if !value.IsValid() {
				return fmt.Errorf("apijson: unexpected invalid value %+#v", value)
			}
			if node.Value() != nil && value.CanSet() {
				value.Set(reflect.ValueOf(node.Value()))
			}
			return nil
		}
	default:
		return d.newPrimitiveTypeDecoder(t)
	}
}
243
244func (d *decoderBuilder) newMapDecoder(t reflect.Type) decoderFunc {
245 keyType := t.Key()
246 itemType := t.Elem()
247 itemDecoder := d.typeDecoder(itemType)
248
249 return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) {
250 mapValue := reflect.MakeMapWithSize(t, len(node.Map()))
251
252 node.ForEach(func(key, value gjson.Result) bool {
253 // It's fine for us to just use `ValueOf` here because the key types will
254 // always be primitive types so we don't need to decode it using the standard pattern
255 keyValue := reflect.ValueOf(key.Value())
256 if !keyValue.IsValid() {
257 if err == nil {
258 err = fmt.Errorf("apijson: received invalid key type %v", keyValue.String())
259 }
260 return false
261 }
262 if keyValue.Type() != keyType {
263 if err == nil {
264 err = fmt.Errorf("apijson: expected key type %v but got %v", keyType, keyValue.Type())
265 }
266 return false
267 }
268
269 itemValue := reflect.New(itemType).Elem()
270 itemerr := itemDecoder(value, itemValue, state)
271 if itemerr != nil {
272 if err == nil {
273 err = itemerr
274 }
275 return false
276 }
277
278 mapValue.SetMapIndex(keyValue, itemValue)
279 return true
280 })
281
282 if err != nil {
283 return err
284 }
285 value.Set(mapValue)
286 return nil
287 }
288}
289
290func (d *decoderBuilder) newArrayTypeDecoder(t reflect.Type) decoderFunc {
291 itemDecoder := d.typeDecoder(t.Elem())
292
293 return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) {
294 if !node.IsArray() {
295 return fmt.Errorf("apijson: could not deserialize to an array")
296 }
297
298 arrayNode := node.Array()
299
300 arrayValue := reflect.MakeSlice(reflect.SliceOf(t.Elem()), len(arrayNode), len(arrayNode))
301 for i, itemNode := range arrayNode {
302 err = itemDecoder(itemNode, arrayValue.Index(i), state)
303 if err != nil {
304 return err
305 }
306 }
307
308 value.Set(arrayValue)
309 return nil
310 }
311}
312
// newStructTypeDecoder builds a decoder for plain (non-union) struct types.
//
// At build time it classifies each exported field:
//   - embedded (anonymous) fields get the decoder for their own type and are
//     fed the whole JSON object;
//   - fields without a json tag are skipped;
//   - a field tagged `extras` receives JSON members that match no declared
//     field;
//   - `inline` fields are decoded from the whole JSON object;
//   - `metadata` fields are skipped;
//   - every other tagged field is decoded from the JSON member with the tag's
//     name, using the date format from its `format` tag (if any) and the
//     validator registered for it (if any).
//
// The returned decoder always returns nil. A member that fails to decode is
// recorded with an `invalid` metadata status instead. Members with no
// matching field lower the state's exactness to at most `extras`.
func (d *decoderBuilder) newStructTypeDecoder(t reflect.Type) decoderFunc {
	// map of json field name to struct field decoders
	decoderFields := map[string]decoderField{}
	anonymousDecoders := []decoderField{}
	extraDecoder := (*decoderField)(nil)
	var inlineDecoders []decoderField

	validationEntries := validationRegistry[t]

	for i := 0; i < t.NumField(); i++ {
		idx := []int{i}
		field := t.FieldByIndex(idx)
		if !field.IsExported() {
			continue
		}

		// Find the validator registered for this field, matched by field
		// offset. Taking the loop variable's address is safe because the loop
		// exits immediately after.
		var validator *validationEntry
		for _, entry := range validationEntries {
			if entry.field.Offset == field.Offset {
				validator = &entry
				break
			}
		}

		// Embedded fields are decoded with their own type's decoder, which
		// receives the whole JSON object (see the anonymousDecoders loop in
		// the returned func).
		if field.Anonymous {
			anonymousDecoders = append(anonymousDecoders, decoderField{
				fn:  d.typeDecoder(field.Type),
				idx: idx[:],
			})
			continue
		}
		// If json tag is not present, then we skip, which is intentionally
		// different behavior from the stdlib.
		ptag, ok := parseJSONStructTag(field)
		if !ok {
			continue
		}
		// A field tagged `extras` collects JSON members that match no other
		// field. Its decoder targets field.Type.Elem(); presumably the field
		// is a map whose values are decoded one by one (TODO confirm).
		if ptag.extras {
			extraDecoder = &decoderField{ptag, d.typeDecoder(field.Type.Elem()), idx, field.Name}
			continue
		}
		// Inline fields are decoded from the entire JSON object rather than
		// from a single member.
		if ptag.inline {
			df := decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name}
			inlineDecoders = append(inlineDecoders, df)
			continue
		}
		if ptag.metadata {
			continue
		}

		// Temporarily switch the builder's date format to match this field's
		// `format` tag, so time decoders built for the field capture it.
		// Restored once the field's decoder is built.
		oldFormat := d.dateFormat
		dateFormat, ok := parseFormatStructTag(field)
		if ok {
			switch dateFormat {
			case "date-time":
				d.dateFormat = time.RFC3339
			case "date":
				d.dateFormat = "2006-01-02"
			}
		}

		decoderFields[ptag.name] = decoderField{
			ptag,
			d.validatedTypeDecoder(field.Type, validator),
			idx, field.Name,
		}

		d.dateFormat = oldFormat
	}

	return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) {
		// Store the raw JSON in the struct's `JSON.raw` field when present.
		if field := value.FieldByName("JSON"); field.IsValid() {
			if raw := field.FieldByName("raw"); raw.IsValid() {
				setUnexportedField(raw, node.Raw)
			}
		}

		for _, decoder := range anonymousDecoders {
			// ignore errors
			decoder.fn(node, value.FieldByIndex(decoder.idx), state)
		}

		for _, inlineDecoder := range inlineDecoders {
			var meta Field
			dest := value.FieldByIndex(inlineDecoder.idx)
			isValid := false
			// Inline fields are decoded strictly with a private state, so a
			// loose match counts as a failure and the parent's exactness is
			// not touched.
			if dest.IsValid() && node.Type != gjson.Null {
				inlineState := decoderState{exactness: state.exactness, strict: true}
				err = inlineDecoder.fn(node, dest, &inlineState)
				if err == nil {
					isValid = true
				}
			}

			if node.Type == gjson.Null {
				meta = Field{
					raw:    node.Raw,
					status: null,
				}
			} else if !isValid {
				// If an inline decoder fails, unset the field and move on
				// without recording metadata.
				if dest.IsValid() {
					dest.SetZero()
				}
				continue
			} else if isValid {
				meta = Field{
					raw:    node.Raw,
					status: valid,
				}
			}
			setMetadataSubField(value, inlineDecoder.idx, inlineDecoder.goname, meta)
		}

		// Prepare a fresh map for the `extras` field, if the struct has one.
		typedExtraType := reflect.Type(nil)
		typedExtraFields := reflect.Value{}
		if extraDecoder != nil {
			typedExtraType = value.FieldByIndex(extraDecoder.idx).Type()
			typedExtraFields = reflect.MakeMap(typedExtraType)
		}
		// Metadata for every JSON member that matched no declared field.
		untypedExtraFields := map[string]Field{}

		for fieldName, itemNode := range node.Map() {
			df, explicit := decoderFields[fieldName]
			var (
				dest reflect.Value
				fn   decoderFunc
				meta Field
			)
			if explicit {
				fn = df.fn
				dest = value.FieldByIndex(df.idx)
			}
			if !explicit && extraDecoder != nil {
				dest = reflect.New(typedExtraType.Elem()).Elem()
				fn = extraDecoder.fn
			}

			// Per-member errors are not returned; they only determine the
			// member's metadata status below. A member with no destination
			// (unknown, and no extras field) is therefore marked invalid.
			isValid := false
			if dest.IsValid() && itemNode.Type != gjson.Null {
				err = fn(itemNode, dest, state)
				if err == nil {
					isValid = true
				}
			}

			// Handle null [param.Opt]. The UnmarshalJSON error is ignored and
			// no metadata is recorded for the member.
			if itemNode.Type == gjson.Null && dest.IsValid() && dest.Type().Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) {
				dest.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(itemNode.Raw))
				continue
			}

			if itemNode.Type == gjson.Null {
				meta = Field{
					raw:    itemNode.Raw,
					status: null,
				}
			} else if !isValid {
				meta = Field{
					raw:    itemNode.Raw,
					status: invalid,
				}
			} else if isValid {
				meta = Field{
					raw:    itemNode.Raw,
					status: valid,
				}
			}

			if explicit {
				setMetadataSubField(value, df.idx, df.goname, meta)
			}
			if !explicit {
				untypedExtraFields[fieldName] = meta
			}
			if !explicit && extraDecoder != nil {
				typedExtraFields.SetMapIndex(reflect.ValueOf(fieldName), dest)
			}
		}

		if extraDecoder != nil && typedExtraFields.Len() > 0 {
			value.FieldByIndex(extraDecoder.idx).Set(typedExtraFields)
		}

		// Set exactness to 'extras' if there are untyped, extra fields.
		if len(untypedExtraFields) > 0 && state.exactness > extras {
			state.exactness = extras
		}

		if len(untypedExtraFields) > 0 {
			setMetadataExtraFields(value, []int{-1}, "ExtraFields", untypedExtraFields)
		}
		return nil
	}
}
512
513func (d *decoderBuilder) newPrimitiveTypeDecoder(t reflect.Type) decoderFunc {
514 switch t.Kind() {
515 case reflect.String:
516 return func(n gjson.Result, v reflect.Value, state *decoderState) error {
517 v.SetString(n.String())
518 if guardStrict(state, n.Type != gjson.String) {
519 return fmt.Errorf("apijson: failed to parse string strictly")
520 }
521 // Everything that is not an object can be loosely stringified.
522 if n.Type == gjson.JSON {
523 return fmt.Errorf("apijson: failed to parse string")
524 }
525
526 state.validateString(v)
527
528 if guardUnknown(state, v) {
529 return fmt.Errorf("apijson: failed string enum validation")
530 }
531 return nil
532 }
533 case reflect.Bool:
534 return func(n gjson.Result, v reflect.Value, state *decoderState) error {
535 v.SetBool(n.Bool())
536 if guardStrict(state, n.Type != gjson.True && n.Type != gjson.False) {
537 return fmt.Errorf("apijson: failed to parse bool strictly")
538 }
539 // Numbers and strings that are either 'true' or 'false' can be loosely
540 // deserialized as bool.
541 if n.Type == gjson.String && (n.Raw != "true" && n.Raw != "false") || n.Type == gjson.JSON {
542 return fmt.Errorf("apijson: failed to parse bool")
543 }
544
545 state.validateBool(v)
546
547 if guardUnknown(state, v) {
548 return fmt.Errorf("apijson: failed bool enum validation")
549 }
550 return nil
551 }
552 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
553 return func(n gjson.Result, v reflect.Value, state *decoderState) error {
554 v.SetInt(n.Int())
555 if guardStrict(state, n.Type != gjson.Number || n.Num != float64(int(n.Num))) {
556 return fmt.Errorf("apijson: failed to parse int strictly")
557 }
558 // Numbers, booleans, and strings that maybe look like numbers can be
559 // loosely deserialized as numbers.
560 if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) {
561 return fmt.Errorf("apijson: failed to parse int")
562 }
563
564 state.validateInt(v)
565
566 if guardUnknown(state, v) {
567 return fmt.Errorf("apijson: failed int enum validation")
568 }
569 return nil
570 }
571 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
572 return func(n gjson.Result, v reflect.Value, state *decoderState) error {
573 v.SetUint(n.Uint())
574 if guardStrict(state, n.Type != gjson.Number || n.Num != float64(int(n.Num)) || n.Num < 0) {
575 return fmt.Errorf("apijson: failed to parse uint strictly")
576 }
577 // Numbers, booleans, and strings that maybe look like numbers can be
578 // loosely deserialized as uint.
579 if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) {
580 return fmt.Errorf("apijson: failed to parse uint")
581 }
582 if guardUnknown(state, v) {
583 return fmt.Errorf("apijson: failed uint enum validation")
584 }
585 return nil
586 }
587 case reflect.Float32, reflect.Float64:
588 return func(n gjson.Result, v reflect.Value, state *decoderState) error {
589 v.SetFloat(n.Float())
590 if guardStrict(state, n.Type != gjson.Number) {
591 return fmt.Errorf("apijson: failed to parse float strictly")
592 }
593 // Numbers, booleans, and strings that maybe look like numbers can be
594 // loosely deserialized as floats.
595 if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) {
596 return fmt.Errorf("apijson: failed to parse float")
597 }
598 if guardUnknown(state, v) {
599 return fmt.Errorf("apijson: failed float enum validation")
600 }
601 return nil
602 }
603 default:
604 return func(node gjson.Result, v reflect.Value, state *decoderState) error {
605 return fmt.Errorf("unknown type received at primitive decoder: %s", t.String())
606 }
607 }
608}
609
610func (d *decoderBuilder) newOptTypeDecoder(t reflect.Type) decoderFunc {
611 for t.Kind() == reflect.Pointer {
612 t = t.Elem()
613 }
614 valueField, _ := t.FieldByName("Value")
615 return func(n gjson.Result, v reflect.Value, state *decoderState) error {
616 state.validateOptKind(n, valueField.Type)
617 return v.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw))
618 }
619}
620
621func (d *decoderBuilder) newTimeTypeDecoder(t reflect.Type) decoderFunc {
622 format := d.dateFormat
623 return func(n gjson.Result, v reflect.Value, state *decoderState) error {
624 parsed, err := time.Parse(format, n.Str)
625 if err == nil {
626 v.Set(reflect.ValueOf(parsed).Convert(t))
627 return nil
628 }
629
630 if guardStrict(state, true) {
631 return err
632 }
633
634 layouts := []string{
635 "2006-01-02",
636 "2006-01-02T15:04:05Z07:00",
637 "2006-01-02T15:04:05Z0700",
638 "2006-01-02T15:04:05",
639 "2006-01-02 15:04:05Z07:00",
640 "2006-01-02 15:04:05Z0700",
641 "2006-01-02 15:04:05",
642 }
643
644 for _, layout := range layouts {
645 parsed, err := time.Parse(layout, n.Str)
646 if err == nil {
647 v.Set(reflect.ValueOf(parsed).Convert(t))
648 return nil
649 }
650 }
651
652 return fmt.Errorf("unable to leniently parse date-time string: %s", n.Str)
653 }
654}
655
// setUnexportedField assigns value to field even when field is unexported. It
// bypasses reflect's export check by re-deriving the field from its address
// with unsafe. field must be addressable, and value must be assignable to the
// field's type; otherwise reflect panics.
func setUnexportedField(field reflect.Value, value any) {
	addr := unsafe.Pointer(field.UnsafeAddr())
	writable := reflect.NewAt(field.Type(), addr).Elem()
	writable.Set(reflect.ValueOf(value))
}
659
660func guardStrict(state *decoderState, cond bool) bool {
661 if !cond {
662 return false
663 }
664
665 if state.strict {
666 return true
667 }
668
669 state.exactness = loose
670 return false
671}
672
// canParseAsNumber reports whether str is accepted by strconv.ParseFloat as a
// finite-range 64-bit float. Out-of-range values such as "1e400" are rejected.
func canParseAsNumber(str string) bool {
	if _, err := strconv.ParseFloat(str, 64); err != nil {
		return false
	}
	return true
}
677
// stringType is the reflect.Type of the predeclared string type, used to tell
// plain strings apart from named string types.
var stringType = reflect.TypeOf("")
679
680func guardUnknown(state *decoderState, v reflect.Value) bool {
681 if have, ok := v.Interface().(interface{ IsKnown() bool }); guardStrict(state, ok && !have.IsKnown()) {
682 return true
683 }
684
685 constantString, ok := v.Interface().(interface{ Default() string })
686 named := v.Type() != stringType
687 if guardStrict(state, ok && named && v.Equal(reflect.ValueOf(constantString.Default()))) {
688 return true
689 }
690 return false
691}