1package apijson
2
3import (
4 "errors"
5 "github.com/openai/openai-go/packages/param"
6 "reflect"
7
8 "github.com/tidwall/gjson"
9)
10
11func isEmbeddedUnion(t reflect.Type) bool {
12 var apiunion param.APIUnion
13 for i := 0; i < t.NumField(); i++ {
14 if t.Field(i).Type == reflect.TypeOf(apiunion) && t.Field(i).Anonymous {
15 return true
16 }
17 }
18 return false
19}
20
21func RegisterDiscriminatedUnion[T any](key string, mappings map[string]reflect.Type) {
22 var t T
23 entry := unionEntry{
24 discriminatorKey: key,
25 variants: []UnionVariant{},
26 }
27 for k, typ := range mappings {
28 entry.variants = append(entry.variants, UnionVariant{
29 DiscriminatorValue: k,
30 Type: typ,
31 })
32 }
33 unionRegistry[reflect.TypeOf(t)] = entry
34}
35
// newEmbeddedUnionDecoder returns a decoderFunc for the struct type t.
//
// At build time it creates one decoder per field of t, in field order. This
// includes any embedded param.APIUnion field if t has one. The returned
// closure has two passes over unionEntry.variants:
//  1. If a discriminator key is set and the node's value at that key equals a
//     variant's DiscriminatorValue, it decodes that variant directly and
//     returns, skipping the exactness logic.
//  2. Otherwise, it decodes each type-compatible variant into a scratch
//     decoderState. It returns immediately on an exact match, and otherwise
//     keeps the candidate with the highest exactness.
//
// NOTE(review): unionEntry is built locally with an empty variants slice and
// an empty discriminatorKey, and nothing in this function populates it
// (e.g. from unionRegistry). As written, both loops in the closure run zero
// times, bestExactness stays at loose-1, and every call returns
// "apijson: was not able to coerce type as union". Verify whether this
// decoder is reachable before relying on it.
//
// NOTE(review): decoders is built per struct field, but the closure indexes
// it with the variant index (decoders[idx]). The two line up only if variants
// are listed in the same order as t's fields, and decoders[idx] panics if
// there are more variants than fields.
//
// NOTE(review): v.Set(inner) requires a value of type variant.Type to be
// assignable to v's type. If v is the union struct t itself (presumably, since
// this decoder is built for t), setting a single variant's value into it
// would panic. Confirm the intended target (perhaps the matching field of v).
//
// NOTE(review): the local variable unionEntry shadows the unionEntry type
// name for the rest of the function.
func (d *decoderBuilder) newEmbeddedUnionDecoder(t reflect.Type) decoderFunc {
	decoders := []decoderFunc{}

	// One decoder per struct field, in declaration order.
	for i := 0; i < t.NumField(); i++ {
		variant := t.Field(i)
		decoder := d.typeDecoder(variant.Type)
		decoders = append(decoders, decoder)
	}

	unionEntry := unionEntry{
		variants: []UnionVariant{},
	}

	return func(n gjson.Result, v reflect.Value, state *decoderState) error {
		// If there is a discriminator match, circumvent the exactness logic entirely
		for idx, variant := range unionEntry.variants {
			decoder := decoders[idx]
			// Only variants whose TypeFilter equals the node's gjson type are candidates.
			if variant.TypeFilter != n.Type {
				continue
			}

			if len(unionEntry.discriminatorKey) != 0 {
				discriminatorValue := n.Get(unionEntry.discriminatorKey).Value()
				if discriminatorValue == variant.DiscriminatorValue {
					inner := reflect.New(variant.Type).Elem()
					// Decodes with the caller's state, not a scratch one. inner is
					// stored into v even when err is non-nil.
					err := decoder(n, inner, state)
					v.Set(inner)
					return err
				}
			}
		}

		// Set bestExactness to worse than loose
		bestExactness := loose - 1
		for idx, variant := range unionEntry.variants {
			decoder := decoders[idx]
			if variant.TypeFilter != n.Type {
				continue
			}
			// Each attempt gets its own state starting at exact, so the decoder's
			// downgrades measure how well this variant fits.
			sub := decoderState{strict: state.strict, exactness: exact}
			inner := reflect.New(variant.Type).Elem()
			err := decoder(n, inner, &sub)
			if err != nil {
				continue
			}
			if sub.exactness == exact {
				v.Set(inner)
				return nil
			}
			// Keep the best non-exact candidate so far. On ties, the earlier variant wins.
			if sub.exactness > bestExactness {
				v.Set(inner)
				bestExactness = sub.exactness
			}
		}

		// No variant decoded at least loosely.
		if bestExactness < loose {
			return errors.New("apijson: was not able to coerce type as union")
		}

		// guardStrict is defined elsewhere; presumably it reports whether strict
		// mode should reject a non-exact match — verify.
		if guardStrict(state, bestExactness != exact) {
			return errors.New("apijson: was not able to coerce type as union strictly")
		}

		return nil
	}
}