union.go

  1package apijson
  2
import (
	"errors"
	"reflect"
	"sort"

	"github.com/openai/openai-go/packages/param"
	"github.com/tidwall/gjson"
)
 10
 11func isEmbeddedUnion(t reflect.Type) bool {
 12	var apiunion param.APIUnion
 13	for i := 0; i < t.NumField(); i++ {
 14		if t.Field(i).Type == reflect.TypeOf(apiunion) && t.Field(i).Anonymous {
 15			return true
 16		}
 17	}
 18	return false
 19}
 20
 21func RegisterDiscriminatedUnion[T any](key string, mappings map[string]reflect.Type) {
 22	var t T
 23	entry := unionEntry{
 24		discriminatorKey: key,
 25		variants:         []UnionVariant{},
 26	}
 27	for k, typ := range mappings {
 28		entry.variants = append(entry.variants, UnionVariant{
 29			DiscriminatorValue: k,
 30			Type:               typ,
 31		})
 32	}
 33	unionRegistry[reflect.TypeOf(t)] = entry
 34}
 35
 36func (d *decoderBuilder) newEmbeddedUnionDecoder(t reflect.Type) decoderFunc {
 37	decoders := []decoderFunc{}
 38
 39	for i := 0; i < t.NumField(); i++ {
 40		variant := t.Field(i)
 41		decoder := d.typeDecoder(variant.Type)
 42		decoders = append(decoders, decoder)
 43	}
 44
 45	unionEntry := unionEntry{
 46		variants: []UnionVariant{},
 47	}
 48
 49	return func(n gjson.Result, v reflect.Value, state *decoderState) error {
 50		// If there is a discriminator match, circumvent the exactness logic entirely
 51		for idx, variant := range unionEntry.variants {
 52			decoder := decoders[idx]
 53			if variant.TypeFilter != n.Type {
 54				continue
 55			}
 56
 57			if len(unionEntry.discriminatorKey) != 0 {
 58				discriminatorValue := n.Get(unionEntry.discriminatorKey).Value()
 59				if discriminatorValue == variant.DiscriminatorValue {
 60					inner := reflect.New(variant.Type).Elem()
 61					err := decoder(n, inner, state)
 62					v.Set(inner)
 63					return err
 64				}
 65			}
 66		}
 67
 68		// Set bestExactness to worse than loose
 69		bestExactness := loose - 1
 70		for idx, variant := range unionEntry.variants {
 71			decoder := decoders[idx]
 72			if variant.TypeFilter != n.Type {
 73				continue
 74			}
 75			sub := decoderState{strict: state.strict, exactness: exact}
 76			inner := reflect.New(variant.Type).Elem()
 77			err := decoder(n, inner, &sub)
 78			if err != nil {
 79				continue
 80			}
 81			if sub.exactness == exact {
 82				v.Set(inner)
 83				return nil
 84			}
 85			if sub.exactness > bestExactness {
 86				v.Set(inner)
 87				bestExactness = sub.exactness
 88			}
 89		}
 90
 91		if bestExactness < loose {
 92			return errors.New("apijson: was not able to coerce type as union")
 93		}
 94
 95		if guardStrict(state, bestExactness != exact) {
 96			return errors.New("apijson: was not able to coerce type as union strictly")
 97		}
 98
 99		return nil
100	}
101}