1package apiquery
2
import (
	"encoding/json"
	"fmt"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	internalparam "github.com/openai/openai-go/internal/param"
	"github.com/openai/openai-go/packages/param"
)
15
16var encoders sync.Map // map[reflect.Type]encoderFunc
17
18type encoder struct {
19 dateFormat string
20 root bool
21 settings QuerySettings
22}
23
24type encoderFunc func(key string, value reflect.Value) []Pair
25
26type encoderField struct {
27 tag parsedStructTag
28 fn encoderFunc
29 idx []int
30}
31
32type encoderEntry struct {
33 reflect.Type
34 dateFormat string
35 root bool
36 settings QuerySettings
37}
38
39type Pair struct {
40 key string
41 value string
42}
43
// typeEncoder returns the encoderFunc for t under the encoder's current
// configuration (dateFormat, root, settings), building it on first use and
// caching it in encoders afterwards.
func (e *encoder) typeEncoder(t reflect.Type) encoderFunc {
	cacheKey := encoderEntry{
		Type:       t,
		dateFormat: e.dateFormat,
		root:       e.root,
		settings:   e.settings,
	}

	if cached, hit := encoders.Load(cacheKey); hit {
		return cached.(encoderFunc)
	}

	// Register a placeholder before building so that a recursive type, whose
	// construction reaches this same key again, finds the placeholder instead
	// of recursing forever. The placeholder blocks until the real encoder has
	// been built and then delegates to it.
	var (
		ready sync.WaitGroup
		impl  encoderFunc
	)
	ready.Add(1)
	placeholder := encoderFunc(func(k string, v reflect.Value) []Pair {
		ready.Wait()
		return impl(k, v)
	})
	if existing, raced := encoders.LoadOrStore(cacheKey, placeholder); raced {
		// Another goroutine registered this key between the Load above and
		// here; use its entry.
		return existing.(encoderFunc)
	}

	// Build the real encoder, release anyone waiting on the placeholder, then
	// replace the placeholder in the cache.
	impl = e.newTypeEncoder(t)
	ready.Done()
	encoders.Store(cacheKey, impl)
	return impl
}
79
// marshalerEncoder encodes a value implementing json.Marshaler as a single
// pair whose value is the raw output of MarshalJSON.
//
// Invalid values, nil pointers and nil interfaces produce no pairs. This
// mirrors the pointer encoder's handling of nil and avoids calling
// MarshalJSON through a nil receiver, which would panic for value-receiver
// methods. If MarshalJSON returns an error, no pair is produced rather than
// emitting the key with an empty (and therefore wrong) value.
func marshalerEncoder(key string, value reflect.Value) []Pair {
	if !value.IsValid() {
		return nil
	}
	switch value.Kind() {
	case reflect.Pointer, reflect.Interface:
		if value.IsNil() {
			return nil
		}
	}
	m, ok := value.Interface().(json.Marshaler)
	if !ok {
		return nil
	}
	s, err := m.MarshalJSON()
	if err != nil {
		return nil
	}
	return []Pair{{key, string(s)}}
}
84
// newTypeEncoder builds a fresh encoderFunc for t; callers normally go
// through typeEncoder, which caches the result.
//
// Types convertible to time.Time get the time encoder. Otherwise, unless the
// encoder is at the root or t is one of param.OptionalPrimitiveTypes, types
// implementing json.Marshaler are encoded via MarshalJSON. Everything else is
// dispatched on t's reflect.Kind. From that dispatch onward e.root is false,
// so encoders built for nested types never see the root flag.
func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc {
	if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
		return e.newTimeTypeEncoder(t)
	}

	marshalerType := reflect.TypeOf((*json.Marshaler)(nil)).Elem()
	if !e.root && t.Implements(marshalerType) && param.OptionalPrimitiveTypes[t] == nil {
		return marshalerEncoder
	}

	e.root = false

	switch t.Kind() {
	case reflect.Pointer:
		elemEnc := e.typeEncoder(t.Elem())
		return func(key string, value reflect.Value) []Pair {
			if value.IsValid() && !value.IsNil() {
				return elemEnc(key, value.Elem())
			}
			return nil
		}
	case reflect.Struct:
		return e.newStructTypeEncoder(t)
	case reflect.Array, reflect.Slice:
		return e.newArrayTypeEncoder(t)
	case reflect.Map:
		return e.newMapEncoder(t)
	case reflect.Interface:
		return e.newInterfaceEncoder()
	}
	return e.newPrimitiveTypeEncoder(t)
}
118
// newStructTypeEncoder builds an encoder for struct type t.
//
// Types implementing internalparam.FieldLike and those listed in
// param.OptionalPrimitiveTypes have dedicated encoders. Any other struct is
// flattened into a list of fields. An exported field carrying a `query` tag
// is encoded under renderKeyPath(key, name), or at key itself when the tag is
// inline. Exported embedded structs are traversed so their fields are
// promoted. Fields without a query tag, and fields whose name is "-" or empty
// without being inline, are skipped. A `format` tag of "date-time" or "date"
// selects the time layout used for that field only.
func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
	if t.Implements(reflect.TypeOf((*internalparam.FieldLike)(nil)).Elem()) {
		return e.newFieldTypeEncoder(t)
	}

	if idx, ok := param.OptionalPrimitiveTypes[t]; ok {
		return e.newRichFieldTypeEncoder(t, idx)
	}

	encoderFields := []encoderField{}

	// collectEncoderFields walks r (t itself or a struct embedded in it) and
	// records an encoderField for every encodable field. index is the path
	// from t down to r, so each recorded idx can be passed to FieldByIndex.
	var collectEncoderFields func(r reflect.Type, index []int)
	collectEncoderFields = func(r reflect.Type, index []int) {
		for i := 0; i < r.NumField(); i++ {
			// Cap index before appending so every idx gets its own backing
			// array. A plain append(index, i) can reuse index's spare
			// capacity, letting later iterations overwrite paths already
			// stored in encoderFields (this happens once structs are embedded
			// three or more levels deep).
			idx := append(index[:len(index):len(index)], i)
			field := t.FieldByIndex(idx)
			if !field.IsExported() {
				continue
			}
			// Promote the fields of embedded structs. Only struct kinds are
			// traversed: NumField panics for any other kind (e.g. an embedded
			// pointer), so those are handled like ordinary fields below.
			if field.Anonymous && field.Type.Kind() == reflect.Struct {
				collectEncoderFields(field.Type, idx)
				continue
			}
			// If the query tag is not present, skip the field. This is
			// intentionally different behavior from the stdlib.
			ptag, ok := parseQueryStructTag(field)
			if !ok {
				continue
			}
			if (ptag.name == "-" || ptag.name == "") && !ptag.inline {
				continue
			}

			// A `format` tag switches the time layout while this field's
			// encoder is built; the previous layout is restored afterwards.
			oldFormat := e.dateFormat
			if dateFormat, ok := parseFormatStructTag(field); ok {
				switch dateFormat {
				case "date-time":
					e.dateFormat = time.RFC3339
				case "date":
					e.dateFormat = "2006-01-02"
				}
			}

			encoderFn := e.typeEncoder(field.Type)
			if ptag.omitzero {
				typeEncoderFn := encoderFn
				encoderFn = func(key string, value reflect.Value) []Pair {
					if value.IsZero() {
						return nil
					}
					return typeEncoderFn(key, value)
				}
			}
			encoderFields = append(encoderFields, encoderField{tag: ptag, fn: encoderFn, idx: idx})
			e.dateFormat = oldFormat
		}
	}
	collectEncoderFields(t, []int{})

	return func(key string, value reflect.Value) (pairs []Pair) {
		for _, ef := range encoderFields {
			subkey := key
			if !ef.tag.inline {
				subkey = e.renderKeyPath(key, ef.tag.name)
			}
			pairs = append(pairs, ef.fn(subkey, value.FieldByIndex(ef.idx))...)
		}
		return pairs
	}
}
199
200func (e *encoder) newMapEncoder(t reflect.Type) encoderFunc {
201 keyEncoder := e.typeEncoder(t.Key())
202 elementEncoder := e.typeEncoder(t.Elem())
203 return func(key string, value reflect.Value) (pairs []Pair) {
204 iter := value.MapRange()
205 for iter.Next() {
206 encodedKey := keyEncoder("", iter.Key())
207 if len(encodedKey) != 1 {
208 panic("Unexpected number of parts for encoded map key. Are you using a non-primitive for this map?")
209 }
210 subkey := encodedKey[0].value
211 keyPath := e.renderKeyPath(key, subkey)
212 pairs = append(pairs, elementEncoder(keyPath, iter.Value())...)
213 }
214 return
215 }
216}
217
218func (e *encoder) renderKeyPath(key string, subkey string) string {
219 if len(key) == 0 {
220 return subkey
221 }
222 if e.settings.NestedFormat == NestedQueryFormatDots {
223 return fmt.Sprintf("%s.%s", key, subkey)
224 }
225 return fmt.Sprintf("%s[%s]", key, subkey)
226}
227
// newArrayTypeEncoder builds an encoder for array or slice type t, laid out
// according to e.settings.ArrayFormat:
//
//   - ArrayQueryFormatComma:    key=a,b,c (a single pair; none when empty)
//   - ArrayQueryFormatRepeat:   key=a&key=b&key=c
//   - ArrayQueryFormatIndices:  key[0]=a&key[1]=b&key[2]=c
//   - ArrayQueryFormatBrackets: key[]=a&key[]=b&key[]=c
//
// Nested element values are encoded beneath the per-element key. It panics
// for an unrecognized ArrayFormat.
func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc {
	switch e.settings.ArrayFormat {
	case ArrayQueryFormatComma:
		innerEncoder := e.typeEncoder(t.Elem())
		return func(key string, v reflect.Value) []Pair {
			elements := []string{}
			for i := 0; i < v.Len(); i++ {
				for _, pair := range innerEncoder("", v.Index(i)) {
					elements = append(elements, pair.value)
				}
			}
			if len(elements) == 0 {
				return []Pair{}
			}
			return []Pair{{key, strings.Join(elements, ",")}}
		}
	case ArrayQueryFormatRepeat:
		innerEncoder := e.typeEncoder(t.Elem())
		return func(key string, value reflect.Value) (pairs []Pair) {
			for i := 0; i < value.Len(); i++ {
				pairs = append(pairs, innerEncoder(key, value.Index(i))...)
			}
			return pairs
		}
	case ArrayQueryFormatIndices:
		innerEncoder := e.typeEncoder(t.Elem())
		return func(key string, value reflect.Value) []Pair {
			pairs := []Pair{}
			for i := 0; i < value.Len(); i++ {
				// Indices always use brackets, regardless of the nested
				// object format.
				elemKey := key + "[" + strconv.Itoa(i) + "]"
				pairs = append(pairs, innerEncoder(elemKey, value.Index(i))...)
			}
			return pairs
		}
	case ArrayQueryFormatBrackets:
		innerEncoder := e.typeEncoder(t.Elem())
		return func(key string, value reflect.Value) []Pair {
			pairs := []Pair{}
			for i := 0; i < value.Len(); i++ {
				pairs = append(pairs, innerEncoder(key+"[]", value.Index(i))...)
			}
			return pairs
		}
	default:
		panic(fmt.Sprintf("Unknown ArrayFormat value: %d", e.settings.ArrayFormat))
	}
}
267
// newPrimitiveTypeEncoder builds an encoder for scalar kinds. Each value is
// encoded as a single pair holding its textual form:
//
//   - strings verbatim;
//   - booleans as "true" or "false";
//   - signed and unsigned integers of every width in base 10;
//   - floats in the shortest 'f' form that round-trips at the value's own
//     precision (so float32(0.1) encodes as "0.1");
//   - complex numbers via strconv.FormatComplex.
//
// Pointers are dereferenced, and a nil pointer encodes to nothing. Every
// other kind encodes to nothing.
func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc {
	switch t.Kind() {
	case reflect.Pointer:
		innerEncoder := e.newPrimitiveTypeEncoder(t.Elem())
		return func(key string, v reflect.Value) []Pair {
			if !v.IsValid() || v.IsNil() {
				return nil
			}
			return innerEncoder(key, v.Elem())
		}
	case reflect.String:
		return func(key string, v reflect.Value) []Pair {
			return []Pair{{key, v.String()}}
		}
	case reflect.Bool:
		return func(key string, v reflect.Value) []Pair {
			return []Pair{{key, strconv.FormatBool(v.Bool())}}
		}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return func(key string, v reflect.Value) []Pair {
			return []Pair{{key, strconv.FormatInt(v.Int(), 10)}}
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return func(key string, v reflect.Value) []Pair {
			return []Pair{{key, strconv.FormatUint(v.Uint(), 10)}}
		}
	case reflect.Float32, reflect.Float64:
		// Format at the type's own precision; formatting a float32 as 64-bit
		// exposes widening noise ("0.10000000149011612").
		bitSize := t.Bits()
		return func(key string, v reflect.Value) []Pair {
			return []Pair{{key, strconv.FormatFloat(v.Float(), 'f', -1, bitSize)}}
		}
	case reflect.Complex64, reflect.Complex128:
		// Bits is 64 for complex64 and 128 for complex128, as FormatComplex expects.
		bitSize := t.Bits()
		return func(key string, v reflect.Value) []Pair {
			return []Pair{{key, strconv.FormatComplex(v.Complex(), 'f', -1, bitSize)}}
		}
	default:
		return func(key string, v reflect.Value) []Pair {
			return nil
		}
	}
}
317
// newFieldTypeEncoder builds an encoder for FieldLike wrapper structs, which
// are accessed here through their Present, Null, Raw and Value fields.
// Nothing is emitted when Present is false or Null is true. Otherwise a
// non-nil Raw is encoded in preference to Value.
func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc {
	f, _ := t.FieldByName("Value")
	enc := e.typeEncoder(f.Type)

	return func(key string, value reflect.Value) []Pair {
		present := value.FieldByName("Present")
		if !present.Bool() {
			return nil
		}
		null := value.FieldByName("Null")
		if null.Bool() {
			// TODO: Error?
			return nil
		}
		raw := value.FieldByName("Raw")
		if !raw.IsNil() {
			// The encoder for Raw's type is resolved per call. typeEncoder may
			// mutate its receiver while building a new encoder (root,
			// dateFormat), and this closure is cached and shared by every
			// caller. Work on a per-call copy so concurrent encodes neither
			// race on e nor observe each other's temporary state.
			rawEncoder := *e
			return rawEncoder.typeEncoder(raw.Type())(key, raw)
		}
		return enc(key, value.FieldByName("Value"))
	}
}
339
// newTimeTypeEncoder builds an encoder for t, a type convertible to
// time.Time. Each value becomes one pair formatted with the layout held in
// e.dateFormat when the encoder is built. Later changes to e.dateFormat do
// not affect an already-built encoder.
func (e *encoder) newTimeTypeEncoder(t reflect.Type) encoderFunc {
	layout := e.dateFormat
	timeType := reflect.TypeOf(time.Time{})
	return func(key string, value reflect.Value) []Pair {
		ts := value.Convert(timeType).Interface().(time.Time)
		return []Pair{{key, ts.Format(layout)}}
	}
}
349
// newInterfaceEncoder builds an encoder for interface-typed values. It
// unwraps the dynamic value and encodes it with the encoder for its concrete
// type; a nil interface encodes to nothing.
//
// Because of the value receiver, the returned closure captures a copy of the
// encoder taken at build time. Its configuration (e.g. dateFormat) is what is
// used when encoding.
func (e encoder) newInterfaceEncoder() encoderFunc {
	return func(key string, value reflect.Value) []Pair {
		value = value.Elem()
		if !value.IsValid() {
			return nil
		}
		// typeEncoder may mutate its receiver while building a new encoder
		// (root, dateFormat). The captured e is shared by every call of this
		// cached closure, so use a per-call copy to avoid data races between
		// concurrent encodes and state leaking from one call to the next.
		enc := e
		return enc.typeEncoder(value.Type())(key, value)
	}
}