union.go

  1package apijson
  2
  import (
	"errors"
	"reflect"
	"sort"

	"github.com/openai/openai-go/packages/param"
	"github.com/tidwall/gjson"
)
 10
 11var apiUnionType = reflect.TypeOf(param.APIUnion{})
 12
 13func isStructUnion(t reflect.Type) bool {
 14	if t.Kind() != reflect.Struct {
 15		return false
 16	}
 17	for i := 0; i < t.NumField(); i++ {
 18		if t.Field(i).Type == apiUnionType && t.Field(i).Anonymous {
 19			return true
 20		}
 21	}
 22	return false
 23}
 24
 25func RegisterDiscriminatedUnion[T any](key string, mappings map[string]reflect.Type) {
 26	var t T
 27	entry := unionEntry{
 28		discriminatorKey: key,
 29		variants:         []UnionVariant{},
 30	}
 31	for k, typ := range mappings {
 32		entry.variants = append(entry.variants, UnionVariant{
 33			DiscriminatorValue: k,
 34			Type:               typ,
 35		})
 36	}
 37	unionRegistry[reflect.TypeOf(t)] = entry
 38}
 39
 40func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc {
 41	type variantDecoder struct {
 42		decoder            decoderFunc
 43		field              reflect.StructField
 44		discriminatorValue any
 45	}
 46
 47	variants := []variantDecoder{}
 48	for i := 0; i < t.NumField(); i++ {
 49		field := t.Field(i)
 50
 51		if field.Anonymous && field.Type == apiUnionType {
 52			continue
 53		}
 54
 55		decoder := d.typeDecoder(field.Type)
 56		variants = append(variants, variantDecoder{
 57			decoder: decoder,
 58			field:   field,
 59		})
 60	}
 61
 62	unionEntry, discriminated := unionRegistry[t]
 63	for _, unionVariant := range unionEntry.variants {
 64		for i := 0; i < len(variants); i++ {
 65			variant := &variants[i]
 66			if variant.field.Type.Elem() == unionVariant.Type {
 67				variant.discriminatorValue = unionVariant.DiscriminatorValue
 68				break
 69			}
 70		}
 71	}
 72
 73	return func(n gjson.Result, v reflect.Value, state *decoderState) error {
 74		if discriminated && n.Type == gjson.JSON && len(unionEntry.discriminatorKey) != 0 {
 75			discriminator := n.Get(unionEntry.discriminatorKey).Value()
 76			for _, variant := range variants {
 77				if discriminator == variant.discriminatorValue {
 78					inner := v.FieldByIndex(variant.field.Index)
 79					return variant.decoder(n, inner, state)
 80				}
 81			}
 82			return errors.New("apijson: was not able to find discriminated union variant")
 83		}
 84
 85		// Set bestExactness to worse than loose
 86		bestExactness := loose - 1
 87		bestVariant := -1
 88		for i, variant := range variants {
 89			// Pointers are used to discern JSON object variants from value variants
 90			if n.Type != gjson.JSON && variant.field.Type.Kind() == reflect.Ptr {
 91				continue
 92			}
 93
 94			sub := decoderState{strict: state.strict, exactness: exact}
 95			inner := v.FieldByIndex(variant.field.Index)
 96			err := variant.decoder(n, inner, &sub)
 97			if err != nil {
 98				continue
 99			}
100			if sub.exactness == exact {
101				bestExactness = exact
102				bestVariant = i
103				break
104			}
105			if sub.exactness > bestExactness {
106				bestExactness = sub.exactness
107				bestVariant = i
108			}
109		}
110
111		if bestExactness < loose {
112			return errors.New("apijson: was not able to coerce type as union")
113		}
114
115		if guardStrict(state, bestExactness != exact) {
116			return errors.New("apijson: was not able to coerce type as union strictly")
117		}
118
119		for i := 0; i < len(variants); i++ {
120			if i == bestVariant {
121				continue
122			}
123			v.FieldByIndex(variants[i].field.Index).SetZero()
124		}
125
126		return nil
127	}
128}
129
130// newUnionDecoder returns a decoderFunc that deserializes into a union using an
131// algorithm roughly similar to Pydantic's [smart algorithm].
132//
133// Conceptually this is equivalent to choosing the best schema based on how 'exact'
134// the deserialization is for each of the schemas.
135//
136// If there is a tie in the level of exactness, then the tie is broken
137// left-to-right.
138//
139// [smart algorithm]: https://docs.pydantic.dev/latest/concepts/unions/#smart-mode
140func (d *decoderBuilder) newUnionDecoder(t reflect.Type) decoderFunc {
141	unionEntry, ok := unionRegistry[t]
142	if !ok {
143		panic("apijson: couldn't find union of type " + t.String() + " in union registry")
144	}
145	decoders := []decoderFunc{}
146	for _, variant := range unionEntry.variants {
147		decoder := d.typeDecoder(variant.Type)
148		decoders = append(decoders, decoder)
149	}
150	return func(n gjson.Result, v reflect.Value, state *decoderState) error {
151		// If there is a discriminator match, circumvent the exactness logic entirely
152		for idx, variant := range unionEntry.variants {
153			decoder := decoders[idx]
154			if variant.TypeFilter != n.Type {
155				continue
156			}
157
158			if len(unionEntry.discriminatorKey) != 0 {
159				discriminatorValue := n.Get(unionEntry.discriminatorKey).Value()
160				if discriminatorValue == variant.DiscriminatorValue {
161					inner := reflect.New(variant.Type).Elem()
162					err := decoder(n, inner, state)
163					v.Set(inner)
164					return err
165				}
166			}
167		}
168
169		// Set bestExactness to worse than loose
170		bestExactness := loose - 1
171		for idx, variant := range unionEntry.variants {
172			decoder := decoders[idx]
173			if variant.TypeFilter != n.Type {
174				continue
175			}
176			sub := decoderState{strict: state.strict, exactness: exact}
177			inner := reflect.New(variant.Type).Elem()
178			err := decoder(n, inner, &sub)
179			if err != nil {
180				continue
181			}
182			if sub.exactness == exact {
183				v.Set(inner)
184				return nil
185			}
186			if sub.exactness > bestExactness {
187				v.Set(inner)
188				bestExactness = sub.exactness
189			}
190		}
191
192		if bestExactness < loose {
193			return errors.New("apijson: was not able to coerce type as union")
194		}
195
196		if guardStrict(state, bestExactness != exact) {
197			return errors.New("apijson: was not able to coerce type as union strictly")
198		}
199
200		return nil
201	}
202}