1package apijson
2
import (
	"errors"
	"reflect"
	"sort"

	"github.com/openai/openai-go/packages/param"
	"github.com/tidwall/gjson"
)
10
11var apiUnionType = reflect.TypeOf(param.APIUnion{})
12
13func isStructUnion(t reflect.Type) bool {
14 if t.Kind() != reflect.Struct {
15 return false
16 }
17 for i := 0; i < t.NumField(); i++ {
18 if t.Field(i).Type == apiUnionType && t.Field(i).Anonymous {
19 return true
20 }
21 }
22 return false
23}
24
25func RegisterDiscriminatedUnion[T any](key string, mappings map[string]reflect.Type) {
26 var t T
27 entry := unionEntry{
28 discriminatorKey: key,
29 variants: []UnionVariant{},
30 }
31 for k, typ := range mappings {
32 entry.variants = append(entry.variants, UnionVariant{
33 DiscriminatorValue: k,
34 Type: typ,
35 })
36 }
37 unionRegistry[reflect.TypeOf(t)] = entry
38}
39
40func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc {
41 type variantDecoder struct {
42 decoder decoderFunc
43 field reflect.StructField
44 discriminatorValue any
45 }
46
47 variants := []variantDecoder{}
48 for i := 0; i < t.NumField(); i++ {
49 field := t.Field(i)
50
51 if field.Anonymous && field.Type == apiUnionType {
52 continue
53 }
54
55 decoder := d.typeDecoder(field.Type)
56 variants = append(variants, variantDecoder{
57 decoder: decoder,
58 field: field,
59 })
60 }
61
62 unionEntry, discriminated := unionRegistry[t]
63 for _, unionVariant := range unionEntry.variants {
64 for i := 0; i < len(variants); i++ {
65 variant := &variants[i]
66 if variant.field.Type.Elem() == unionVariant.Type {
67 variant.discriminatorValue = unionVariant.DiscriminatorValue
68 break
69 }
70 }
71 }
72
73 return func(n gjson.Result, v reflect.Value, state *decoderState) error {
74 if discriminated && n.Type == gjson.JSON && len(unionEntry.discriminatorKey) != 0 {
75 discriminator := n.Get(unionEntry.discriminatorKey).Value()
76 for _, variant := range variants {
77 if discriminator == variant.discriminatorValue {
78 inner := v.FieldByIndex(variant.field.Index)
79 return variant.decoder(n, inner, state)
80 }
81 }
82 return errors.New("apijson: was not able to find discriminated union variant")
83 }
84
85 // Set bestExactness to worse than loose
86 bestExactness := loose - 1
87 bestVariant := -1
88 for i, variant := range variants {
89 // Pointers are used to discern JSON object variants from value variants
90 if n.Type != gjson.JSON && variant.field.Type.Kind() == reflect.Ptr {
91 continue
92 }
93
94 sub := decoderState{strict: state.strict, exactness: exact}
95 inner := v.FieldByIndex(variant.field.Index)
96 err := variant.decoder(n, inner, &sub)
97 if err != nil {
98 continue
99 }
100 if sub.exactness == exact {
101 bestExactness = exact
102 bestVariant = i
103 break
104 }
105 if sub.exactness > bestExactness {
106 bestExactness = sub.exactness
107 bestVariant = i
108 }
109 }
110
111 if bestExactness < loose {
112 return errors.New("apijson: was not able to coerce type as union")
113 }
114
115 if guardStrict(state, bestExactness != exact) {
116 return errors.New("apijson: was not able to coerce type as union strictly")
117 }
118
119 for i := 0; i < len(variants); i++ {
120 if i == bestVariant {
121 continue
122 }
123 v.FieldByIndex(variants[i].field.Index).SetZero()
124 }
125
126 return nil
127 }
128}
129
130// newUnionDecoder returns a decoderFunc that deserializes into a union using an
131// algorithm roughly similar to Pydantic's [smart algorithm].
132//
133// Conceptually this is equivalent to choosing the best schema based on how 'exact'
134// the deserialization is for each of the schemas.
135//
136// If there is a tie in the level of exactness, then the tie is broken
137// left-to-right.
138//
139// [smart algorithm]: https://docs.pydantic.dev/latest/concepts/unions/#smart-mode
140func (d *decoderBuilder) newUnionDecoder(t reflect.Type) decoderFunc {
141 unionEntry, ok := unionRegistry[t]
142 if !ok {
143 panic("apijson: couldn't find union of type " + t.String() + " in union registry")
144 }
145 decoders := []decoderFunc{}
146 for _, variant := range unionEntry.variants {
147 decoder := d.typeDecoder(variant.Type)
148 decoders = append(decoders, decoder)
149 }
150 return func(n gjson.Result, v reflect.Value, state *decoderState) error {
151 // If there is a discriminator match, circumvent the exactness logic entirely
152 for idx, variant := range unionEntry.variants {
153 decoder := decoders[idx]
154 if variant.TypeFilter != n.Type {
155 continue
156 }
157
158 if len(unionEntry.discriminatorKey) != 0 {
159 discriminatorValue := n.Get(unionEntry.discriminatorKey).Value()
160 if discriminatorValue == variant.DiscriminatorValue {
161 inner := reflect.New(variant.Type).Elem()
162 err := decoder(n, inner, state)
163 v.Set(inner)
164 return err
165 }
166 }
167 }
168
169 // Set bestExactness to worse than loose
170 bestExactness := loose - 1
171 for idx, variant := range unionEntry.variants {
172 decoder := decoders[idx]
173 if variant.TypeFilter != n.Type {
174 continue
175 }
176 sub := decoderState{strict: state.strict, exactness: exact}
177 inner := reflect.New(variant.Type).Elem()
178 err := decoder(n, inner, &sub)
179 if err != nil {
180 continue
181 }
182 if sub.exactness == exact {
183 v.Set(inner)
184 return nil
185 }
186 if sub.exactness > bestExactness {
187 v.Set(inner)
188 bestExactness = sub.exactness
189 }
190 }
191
192 if bestExactness < loose {
193 return errors.New("apijson: was not able to coerce type as union")
194 }
195
196 if guardStrict(state, bestExactness != exact) {
197 return errors.New("apijson: was not able to coerce type as union strictly")
198 }
199
200 return nil
201 }
202}