1// The deserialization algorithm from apijson may be subject to improvements
  2// between minor versions, particularly with respect to calling [json.Unmarshal]
  3// into param unions.
  4
  5package apijson
  6
  7import (
  8	"encoding/json"
  9	"fmt"
 10	"github.com/openai/openai-go/packages/param"
 11	"reflect"
 12	"strconv"
 13	"sync"
 14	"time"
 15	"unsafe"
 16
 17	"github.com/tidwall/gjson"
 18)
 19
// decoders caches built decoders. It is a synchronized map with roughly the
// following type:
// map[decoderEntry]decoderFunc
//
// The key is a decoderEntry rather than a bare reflect.Type because the
// builder's dateFormat and root flag also change the decoder that is built.
var decoders sync.Map
 23
 24// Unmarshal is similar to [encoding/json.Unmarshal] and parses the JSON-encoded
 25// data and stores it in the given pointer.
 26func Unmarshal(raw []byte, to any) error {
 27	d := &decoderBuilder{dateFormat: time.RFC3339}
 28	return d.unmarshal(raw, to)
 29}
 30
 31// UnmarshalRoot is like Unmarshal, but doesn't try to call MarshalJSON on the
 32// root element. Useful if a struct's UnmarshalJSON is overrode to use the
 33// behavior of this encoder versus the standard library.
 34func UnmarshalRoot(raw []byte, to any) error {
 35	d := &decoderBuilder{dateFormat: time.RFC3339, root: true}
 36	return d.unmarshal(raw, to)
 37}
 38
// decoderBuilder contains the 'compile-time' state of the decoder, i.e. the
// settings that influence which decoderFunc gets built for a type.
type decoderBuilder struct {
	// Whether or not this is the first element and called by [UnmarshalRoot], see
	// the documentation there to see why this is necessary. newTypeDecoder
	// clears it once the root type's UnmarshalJSON checks have been skipped.
	root bool
	// The dateFormat (a layout string as used by [time.Parse]) which is chosen
	// by the last struct tag that was seen.
	dateFormat string
}

// decoderState contains the 'run-time' state of the decoder, threaded through
// every decoderFunc call.
type decoderState struct {
	// strict makes guardStrict report a failed check to the caller instead of
	// merely downgrading exactness to loose.
	strict bool
	// exactness records how closely the input has matched the target type so
	// far.
	exactness exactness
	// validator is the validation entry for the field currently being
	// decoded; validatedTypeDecoder sets it and resets it to nil afterwards.
	validator *validationEntry
}
 55
// Exactness refers to how close to the type the result was if deserialization
// was successful. This is useful in deserializing unions, where you want to try
// each entry, first with strict, then with looser validation, without actually
// having to do a lot of redundant work by unmarshalling twice (or maybe even more
// times).
//
// The constants are ordered from worst to best match (loose < extras < exact),
// and code compares them with < and >, so their order must not change.
type exactness int8

const (
	// Some values had to be fudged a bit, for example by converting a string to an
	// int, or an enum with extra values.
	loose exactness = iota
	// There are some extra arguments, but otherwise it matches the union.
	extras
	// Exactly right.
	exact
)
 72
// decoderFunc decodes node into value. It may update state, for example by
// lowering state.exactness when the input only loosely matches.
type decoderFunc func(node gjson.Result, value reflect.Value, state *decoderState) error

// decoderField describes how to decode one struct field: its parsed json tag,
// its decoder, its index path for reflect.Value.FieldByIndex, and its Go field
// name. newStructTypeDecoder builds these with positional composite literals,
// so the field order here must not change.
type decoderField struct {
	tag    parsedStructTag
	fn     decoderFunc
	idx    []int
	goname string
}

// decoderEntry is the key of the decoders cache. dateFormat and root are part
// of the key because they change the decoder that is built for the same type.
type decoderEntry struct {
	reflect.Type
	dateFormat string
	root       bool
}
 87
 88func (d *decoderBuilder) unmarshal(raw []byte, to any) error {
 89	value := reflect.ValueOf(to).Elem()
 90	result := gjson.ParseBytes(raw)
 91	if !value.IsValid() {
 92		return fmt.Errorf("apijson: cannot marshal into invalid value")
 93	}
 94	return d.typeDecoder(value.Type())(result, value, &decoderState{strict: false, exactness: exact})
 95}
 96
 97// unmarshalWithExactness is used for internal testing purposes.
 98func (d *decoderBuilder) unmarshalWithExactness(raw []byte, to any) (exactness, error) {
 99	value := reflect.ValueOf(to).Elem()
100	result := gjson.ParseBytes(raw)
101	if !value.IsValid() {
102		return 0, fmt.Errorf("apijson: cannot marshal into invalid value")
103	}
104	state := decoderState{strict: false, exactness: exact}
105	err := d.typeDecoder(value.Type())(result, value, &state)
106	return state.exactness, err
107}
108
// typeDecoder returns the decoder for t, building it on first use and caching
// it in the package-level decoders map. The cache key also includes the
// builder's current dateFormat and root flag, since both change the decoder
// that newTypeDecoder builds.
func (d *decoderBuilder) typeDecoder(t reflect.Type) decoderFunc {
	entry := decoderEntry{
		Type:       t,
		dateFormat: d.dateFormat,
		root:       d.root,
	}

	// Fast path: a decoder (or the placeholder below) is already cached.
	if fi, ok := decoders.Load(entry); ok {
		return fi.(decoderFunc)
	}

	// To deal with recursive types, populate the map with an
	// indirect func before we build it. This type waits on the
	// real func (f) to be ready and then calls it. This indirect
	// func is used by recursive references to t that are built while
	// f is being computed, and by any concurrent caller that looks up
	// entry in the meantime; invoking it blocks in wg.Wait until f is set.
	// NOTE(review): if newTypeDecoder panics, wg.Done is never called and
	// anyone invoking the placeholder blocks forever.
	var (
		wg sync.WaitGroup
		f  decoderFunc
	)
	wg.Add(1)
	fi, loaded := decoders.LoadOrStore(entry, decoderFunc(func(node gjson.Result, v reflect.Value, state *decoderState) error {
		wg.Wait()
		return f(node, v, state)
	}))
	// Another goroutine stored a decoder (or its placeholder) first; use it.
	if loaded {
		return fi.(decoderFunc)
	}

	// Compute the real decoder and replace the indirect func with it.
	f = d.newTypeDecoder(t)
	wg.Done()
	decoders.Store(entry, f)
	return f
}
143
144// validatedTypeDecoder wraps the type decoder with a validator. This is helpful
145// for ensuring that enum fields are correct.
146func (d *decoderBuilder) validatedTypeDecoder(t reflect.Type, entry *validationEntry) decoderFunc {
147	dec := d.typeDecoder(t)
148	if entry == nil {
149		return dec
150	}
151
152	// Thread the current validation entry through the decoder,
153	// but clean up in time for the next field.
154	return func(node gjson.Result, v reflect.Value, state *decoderState) error {
155		state.validator = entry
156		err := dec(node, v, state)
157		state.validator = nil
158		return err
159	}
160}
161
162func indirectUnmarshalerDecoder(n gjson.Result, v reflect.Value, state *decoderState) error {
163	return v.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw))
164}
165
166func unmarshalerDecoder(n gjson.Result, v reflect.Value, state *decoderState) error {
167	if v.Kind() == reflect.Pointer && v.CanSet() {
168		v.Set(reflect.New(v.Type().Elem()))
169	}
170	return v.Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw))
171}
172
173func (d *decoderBuilder) newTypeDecoder(t reflect.Type) decoderFunc {
174	if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
175		return d.newTimeTypeDecoder(t)
176	}
177
178	if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) {
179		return d.newOptTypeDecoder(t)
180	}
181
182	if !d.root && t.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) {
183		return unmarshalerDecoder
184	}
185	if !d.root && reflect.PointerTo(t).Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) {
186		if _, ok := unionVariants[t]; !ok {
187			return indirectUnmarshalerDecoder
188		}
189	}
190	d.root = false
191
192	if _, ok := unionRegistry[t]; ok {
193		if isStructUnion(t) {
194			return d.newStructUnionDecoder(t)
195		}
196		return d.newUnionDecoder(t)
197	}
198
199	switch t.Kind() {
200	case reflect.Pointer:
201		inner := t.Elem()
202		innerDecoder := d.typeDecoder(inner)
203
204		return func(n gjson.Result, v reflect.Value, state *decoderState) error {
205			if !v.IsValid() {
206				return fmt.Errorf("apijson: unexpected invalid reflection value %+#v", v)
207			}
208
209			newValue := reflect.New(inner).Elem()
210			err := innerDecoder(n, newValue, state)
211			if err != nil {
212				return err
213			}
214
215			v.Set(newValue.Addr())
216			return nil
217		}
218	case reflect.Struct:
219		if isStructUnion(t) {
220			return d.newStructUnionDecoder(t)
221		}
222		return d.newStructTypeDecoder(t)
223	case reflect.Array:
224		fallthrough
225	case reflect.Slice:
226		return d.newArrayTypeDecoder(t)
227	case reflect.Map:
228		return d.newMapDecoder(t)
229	case reflect.Interface:
230		return func(node gjson.Result, value reflect.Value, state *decoderState) error {
231			if !value.IsValid() {
232				return fmt.Errorf("apijson: unexpected invalid value %+#v", value)
233			}
234			if node.Value() != nil && value.CanSet() {
235				value.Set(reflect.ValueOf(node.Value()))
236			}
237			return nil
238		}
239	default:
240		return d.newPrimitiveTypeDecoder(t)
241	}
242}
243
244func (d *decoderBuilder) newMapDecoder(t reflect.Type) decoderFunc {
245	keyType := t.Key()
246	itemType := t.Elem()
247	itemDecoder := d.typeDecoder(itemType)
248
249	return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) {
250		mapValue := reflect.MakeMapWithSize(t, len(node.Map()))
251
252		node.ForEach(func(key, value gjson.Result) bool {
253			// It's fine for us to just use `ValueOf` here because the key types will
254			// always be primitive types so we don't need to decode it using the standard pattern
255			keyValue := reflect.ValueOf(key.Value())
256			if !keyValue.IsValid() {
257				if err == nil {
258					err = fmt.Errorf("apijson: received invalid key type %v", keyValue.String())
259				}
260				return false
261			}
262			if keyValue.Type() != keyType {
263				if err == nil {
264					err = fmt.Errorf("apijson: expected key type %v but got %v", keyType, keyValue.Type())
265				}
266				return false
267			}
268
269			itemValue := reflect.New(itemType).Elem()
270			itemerr := itemDecoder(value, itemValue, state)
271			if itemerr != nil {
272				if err == nil {
273					err = itemerr
274				}
275				return false
276			}
277
278			mapValue.SetMapIndex(keyValue, itemValue)
279			return true
280		})
281
282		if err != nil {
283			return err
284		}
285		value.Set(mapValue)
286		return nil
287	}
288}
289
290func (d *decoderBuilder) newArrayTypeDecoder(t reflect.Type) decoderFunc {
291	itemDecoder := d.typeDecoder(t.Elem())
292
293	return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) {
294		if !node.IsArray() {
295			return fmt.Errorf("apijson: could not deserialize to an array")
296		}
297
298		arrayNode := node.Array()
299
300		arrayValue := reflect.MakeSlice(reflect.SliceOf(t.Elem()), len(arrayNode), len(arrayNode))
301		for i, itemNode := range arrayNode {
302			err = itemDecoder(itemNode, arrayValue.Index(i), state)
303			if err != nil {
304				return err
305			}
306		}
307
308		value.Set(arrayValue)
309		return nil
310	}
311}
312
// newStructTypeDecoder builds a decoder for a (non-union) struct type.
//
// At build time each exported field is classified as one of:
//   - an embedded (anonymous) field, into which the whole JSON object is decoded;
//   - the `extras` field, a map that receives JSON members with no matching field;
//   - an `inline` field, into which the whole JSON object is decoded strictly;
//   - a `metadata` field, which is skipped;
//   - any other json-tagged field, keyed by its JSON name.
//
// Fields without a json tag are skipped. At run time the decoder stores the raw
// JSON into a `JSON` field's `raw` member when the struct has one, and records
// per-field status metadata. Errors from individual fields are recorded in
// that metadata rather than returned: the returned decoder always returns nil.
func (d *decoderBuilder) newStructTypeDecoder(t reflect.Type) decoderFunc {
	// map of json field name to struct field decoders
	decoderFields := map[string]decoderField{}
	anonymousDecoders := []decoderField{}
	extraDecoder := (*decoderField)(nil)
	var inlineDecoders []decoderField

	validationEntries := validationRegistry[t]

	for i := 0; i < t.NumField(); i++ {
		idx := []int{i}
		field := t.FieldByIndex(idx)
		if !field.IsExported() {
			continue
		}

		// Find the validation entry registered for this field, matched by
		// field offset. Taking &entry is safe because the loop breaks right
		// after, so the variable is never overwritten.
		var validator *validationEntry
		for _, entry := range validationEntries {
			if entry.field.Offset == field.Offset {
				validator = &entry
				break
			}
		}

		// An embedded struct gets a decoder for its whole type; at run time the
		// entire JSON object is decoded into it (errors are ignored).
		if field.Anonymous {
			anonymousDecoders = append(anonymousDecoders, decoderField{
				fn:  d.typeDecoder(field.Type),
				idx: idx[:],
			})
			continue
		}
		// If json tag is not present, then we skip, which is intentionally
		// different behavior from the stdlib.
		ptag, ok := parseJSONStructTag(field)
		if !ok {
			continue
		}
		// The `extras` field collects JSON members that match no other field.
		// It is a map, so its decoder is built for the map's element type.
		// (Unexported fields were already skipped above, so only an exported
		// field can act as the extras field.)
		if ptag.extras {
			extraDecoder = &decoderField{ptag, d.typeDecoder(field.Type.Elem()), idx, field.Name}
			continue
		}
		// Inline fields are decoded from the whole JSON object at run time.
		if ptag.inline {
			df := decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name}
			inlineDecoders = append(inlineDecoders, df)
			continue
		}
		if ptag.metadata {
			continue
		}

		// A `format` struct tag temporarily overrides the date layout used to
		// build this field's decoder; it is restored below so it does not
		// leak into sibling fields.
		oldFormat := d.dateFormat
		dateFormat, ok := parseFormatStructTag(field)
		if ok {
			switch dateFormat {
			case "date-time":
				d.dateFormat = time.RFC3339
			case "date":
				d.dateFormat = "2006-01-02"
			}
		}

		decoderFields[ptag.name] = decoderField{
			ptag,
			d.validatedTypeDecoder(field.Type, validator),
			idx, field.Name,
		}

		d.dateFormat = oldFormat
	}

	return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) {
		// Keep the raw JSON of the whole object in JSON.raw, if present.
		if field := value.FieldByName("JSON"); field.IsValid() {
			if raw := field.FieldByName("raw"); raw.IsValid() {
				setUnexportedField(raw, node.Raw)
			}
		}

		for _, decoder := range anonymousDecoders {
			// ignore errors
			decoder.fn(node, value.FieldByIndex(decoder.idx), state)
		}

		// Each inline field is decoded strictly from the whole node using its
		// own copy of the state, so it does not change this struct's
		// exactness. The error assigned to err here is not returned.
		for _, inlineDecoder := range inlineDecoders {
			var meta Field
			dest := value.FieldByIndex(inlineDecoder.idx)
			isValid := false
			if dest.IsValid() && node.Type != gjson.Null {
				inlineState := decoderState{exactness: state.exactness, strict: true}
				err = inlineDecoder.fn(node, dest, &inlineState)
				if err == nil {
					isValid = true
				}
			}

			if node.Type == gjson.Null {
				meta = Field{
					raw:    node.Raw,
					status: null,
				}
			} else if !isValid {
				// If an inline decoder fails, unset the field and move on.
				if dest.IsValid() {
					dest.SetZero()
				}
				continue
			} else if isValid {
				meta = Field{
					raw:    node.Raw,
					status: valid,
				}
			}
			setMetadataSubField(value, inlineDecoder.idx, inlineDecoder.goname, meta)
		}

		typedExtraType := reflect.Type(nil)
		typedExtraFields := reflect.Value{}
		if extraDecoder != nil {
			typedExtraType = value.FieldByIndex(extraDecoder.idx).Type()
			typedExtraFields = reflect.MakeMap(typedExtraType)
		}
		untypedExtraFields := map[string]Field{}

		// Decode each JSON member into its matching field, or into the typed
		// extras map when no field matches. Per-member failures are recorded
		// as `invalid` metadata; the error assigned to err is not returned.
		for fieldName, itemNode := range node.Map() {
			df, explicit := decoderFields[fieldName]
			var (
				dest reflect.Value
				fn   decoderFunc
				meta Field
			)
			if explicit {
				fn = df.fn
				dest = value.FieldByIndex(df.idx)
			}
			if !explicit && extraDecoder != nil {
				dest = reflect.New(typedExtraType.Elem()).Elem()
				fn = extraDecoder.fn
			}

			isValid := false
			if dest.IsValid() && itemNode.Type != gjson.Null {
				err = fn(itemNode, dest, state)
				if err == nil {
					isValid = true
				}
			}

			// Handle null [param.Opt]
			// NOTE(review): the UnmarshalJSON error is discarded here.
			if itemNode.Type == gjson.Null && dest.IsValid() && dest.Type().Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) {
				dest.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(itemNode.Raw))
				continue
			}

			if itemNode.Type == gjson.Null {
				meta = Field{
					raw:    itemNode.Raw,
					status: null,
				}
			} else if !isValid {
				meta = Field{
					raw:    itemNode.Raw,
					status: invalid,
				}
			} else if isValid {
				meta = Field{
					raw:    itemNode.Raw,
					status: valid,
				}
			}

			// Unknown members are always recorded as untyped extras, even when
			// they were also stored in the typed extras map.
			if explicit {
				setMetadataSubField(value, df.idx, df.goname, meta)
			}
			if !explicit {
				untypedExtraFields[fieldName] = meta
			}
			if !explicit && extraDecoder != nil {
				typedExtraFields.SetMapIndex(reflect.ValueOf(fieldName), dest)
			}
		}

		if extraDecoder != nil && typedExtraFields.Len() > 0 {
			value.FieldByIndex(extraDecoder.idx).Set(typedExtraFields)
		}

		// Set exactness to 'extras' if there are untyped, extra fields.
		if len(untypedExtraFields) > 0 && state.exactness > extras {
			state.exactness = extras
		}

		if len(untypedExtraFields) > 0 {
			setMetadataExtraFields(value, []int{-1}, "ExtraFields", untypedExtraFields)
		}
		return nil
	}
}
512
513func (d *decoderBuilder) newPrimitiveTypeDecoder(t reflect.Type) decoderFunc {
514	switch t.Kind() {
515	case reflect.String:
516		return func(n gjson.Result, v reflect.Value, state *decoderState) error {
517			v.SetString(n.String())
518			if guardStrict(state, n.Type != gjson.String) {
519				return fmt.Errorf("apijson: failed to parse string strictly")
520			}
521			// Everything that is not an object can be loosely stringified.
522			if n.Type == gjson.JSON {
523				return fmt.Errorf("apijson: failed to parse string")
524			}
525
526			state.validateString(v)
527
528			if guardUnknown(state, v) {
529				return fmt.Errorf("apijson: failed string enum validation")
530			}
531			return nil
532		}
533	case reflect.Bool:
534		return func(n gjson.Result, v reflect.Value, state *decoderState) error {
535			v.SetBool(n.Bool())
536			if guardStrict(state, n.Type != gjson.True && n.Type != gjson.False) {
537				return fmt.Errorf("apijson: failed to parse bool strictly")
538			}
539			// Numbers and strings that are either 'true' or 'false' can be loosely
540			// deserialized as bool.
541			if n.Type == gjson.String && (n.Raw != "true" && n.Raw != "false") || n.Type == gjson.JSON {
542				return fmt.Errorf("apijson: failed to parse bool")
543			}
544
545			state.validateBool(v)
546
547			if guardUnknown(state, v) {
548				return fmt.Errorf("apijson: failed bool enum validation")
549			}
550			return nil
551		}
552	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
553		return func(n gjson.Result, v reflect.Value, state *decoderState) error {
554			v.SetInt(n.Int())
555			if guardStrict(state, n.Type != gjson.Number || n.Num != float64(int(n.Num))) {
556				return fmt.Errorf("apijson: failed to parse int strictly")
557			}
558			// Numbers, booleans, and strings that maybe look like numbers can be
559			// loosely deserialized as numbers.
560			if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) {
561				return fmt.Errorf("apijson: failed to parse int")
562			}
563
564			state.validateInt(v)
565
566			if guardUnknown(state, v) {
567				return fmt.Errorf("apijson: failed int enum validation")
568			}
569			return nil
570		}
571	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
572		return func(n gjson.Result, v reflect.Value, state *decoderState) error {
573			v.SetUint(n.Uint())
574			if guardStrict(state, n.Type != gjson.Number || n.Num != float64(int(n.Num)) || n.Num < 0) {
575				return fmt.Errorf("apijson: failed to parse uint strictly")
576			}
577			// Numbers, booleans, and strings that maybe look like numbers can be
578			// loosely deserialized as uint.
579			if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) {
580				return fmt.Errorf("apijson: failed to parse uint")
581			}
582			if guardUnknown(state, v) {
583				return fmt.Errorf("apijson: failed uint enum validation")
584			}
585			return nil
586		}
587	case reflect.Float32, reflect.Float64:
588		return func(n gjson.Result, v reflect.Value, state *decoderState) error {
589			v.SetFloat(n.Float())
590			if guardStrict(state, n.Type != gjson.Number) {
591				return fmt.Errorf("apijson: failed to parse float strictly")
592			}
593			// Numbers, booleans, and strings that maybe look like numbers can be
594			// loosely deserialized as floats.
595			if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) {
596				return fmt.Errorf("apijson: failed to parse float")
597			}
598			if guardUnknown(state, v) {
599				return fmt.Errorf("apijson: failed float enum validation")
600			}
601			return nil
602		}
603	default:
604		return func(node gjson.Result, v reflect.Value, state *decoderState) error {
605			return fmt.Errorf("unknown type received at primitive decoder: %s", t.String())
606		}
607	}
608}
609
610func (d *decoderBuilder) newOptTypeDecoder(t reflect.Type) decoderFunc {
611	for t.Kind() == reflect.Pointer {
612		t = t.Elem()
613	}
614	valueField, _ := t.FieldByName("Value")
615	return func(n gjson.Result, v reflect.Value, state *decoderState) error {
616		state.validateOptKind(n, valueField.Type)
617		return v.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw))
618	}
619}
620
621func (d *decoderBuilder) newTimeTypeDecoder(t reflect.Type) decoderFunc {
622	format := d.dateFormat
623	return func(n gjson.Result, v reflect.Value, state *decoderState) error {
624		parsed, err := time.Parse(format, n.Str)
625		if err == nil {
626			v.Set(reflect.ValueOf(parsed).Convert(t))
627			return nil
628		}
629
630		if guardStrict(state, true) {
631			return err
632		}
633
634		layouts := []string{
635			"2006-01-02",
636			"2006-01-02T15:04:05Z07:00",
637			"2006-01-02T15:04:05Z0700",
638			"2006-01-02T15:04:05",
639			"2006-01-02 15:04:05Z07:00",
640			"2006-01-02 15:04:05Z0700",
641			"2006-01-02 15:04:05",
642		}
643
644		for _, layout := range layouts {
645			parsed, err := time.Parse(layout, n.Str)
646			if err == nil {
647				v.Set(reflect.ValueOf(parsed).Convert(t))
648				return nil
649			}
650		}
651
652		return fmt.Errorf("unable to leniently parse date-time string: %s", n.Str)
653	}
654}
655
// setUnexportedField assigns value to field even when field is an unexported
// struct field, which reflect.Value.Set would otherwise refuse. It builds a
// fresh, writable reflect.Value over the same memory via unsafe. field must be
// addressable and value must be assignable to field's type, or this panics.
func setUnexportedField(field reflect.Value, value any) {
	addr := unsafe.Pointer(field.UnsafeAddr())
	writable := reflect.NewAt(field.Type(), addr).Elem()
	writable.Set(reflect.ValueOf(value))
}
659
660func guardStrict(state *decoderState, cond bool) bool {
661	if !cond {
662		return false
663	}
664
665	if state.strict {
666		return true
667	}
668
669	state.exactness = loose
670	return false
671}
672
// canParseAsNumber reports whether strconv.ParseFloat accepts str as a 64-bit
// float without error. That includes forms such as "1e3", "0x1p-2", "NaN" and
// "Inf", and excludes values that overflow float64 (ParseFloat reports a range
// error for those).
func canParseAsNumber(str string) bool {
	if _, err := strconv.ParseFloat(str, 64); err != nil {
		return false
	}
	return true
}
677
678var stringType = reflect.TypeOf(string(""))
679
680func guardUnknown(state *decoderState, v reflect.Value) bool {
681	if have, ok := v.Interface().(interface{ IsKnown() bool }); guardStrict(state, ok && !have.IsKnown()) {
682		return true
683	}
684
685	constantString, ok := v.Interface().(interface{ Default() string })
686	named := v.Type() != stringType
687	if guardStrict(state, ok && named && v.Equal(reflect.ValueOf(constantString.Default()))) {
688		return true
689	}
690	return false
691}