1package apiquery
2
3import (
4 "encoding/json"
5 "fmt"
6 "reflect"
7 "strconv"
8 "strings"
9 "sync"
10 "time"
11
12 "github.com/anthropics/anthropic-sdk-go/packages/param"
13)
14
15var encoders sync.Map // map[reflect.Type]encoderFunc
16
17type encoder struct {
18 dateFormat string
19 root bool
20 settings QuerySettings
21}
22
23type encoderFunc func(key string, value reflect.Value) ([]Pair, error)
24
25type encoderField struct {
26 tag parsedStructTag
27 fn encoderFunc
28 idx []int
29}
30
31type encoderEntry struct {
32 reflect.Type
33 dateFormat string
34 root bool
35 settings QuerySettings
36}
37
38type Pair struct {
39 key string
40 value string
41}
42
// typeEncoder returns the encoder for t under e's current dateFormat, root and
// settings. Encoders are memoized in the package-level encoders map keyed by
// encoderEntry, so one type can have several cached encoders that differ in
// that state.
//
// NOTE(review): if e.newTypeEncoder panics, wg.Done is never reached and the
// placeholder stored by LoadOrStore is never replaced. Any later lookup of the
// same entry returns that placeholder, and calling it blocks forever in
// wg.Wait. Consider a deferred cleanup that removes the entry on panic.
func (e *encoder) typeEncoder(t reflect.Type) encoderFunc {
	entry := encoderEntry{
		Type:       t,
		dateFormat: e.dateFormat,
		root:       e.root,
		settings:   e.settings,
	}

	// Fast path: an encoder (or an in-progress placeholder) already exists.
	if fi, ok := encoders.Load(entry); ok {
		return fi.(encoderFunc)
	}

	// To deal with recursive types, populate the map with an
	// indirect func before we build it. This func waits on the
	// real func (f) to be ready and then calls it. It is what
	// recursive lookups of t see while t is being built, and what
	// a concurrent caller receives if it loses the LoadOrStore race.
	var (
		wg sync.WaitGroup
		f  encoderFunc
	)
	wg.Add(1)
	fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(key string, v reflect.Value) ([]Pair, error) {
		wg.Wait()
		return f(key, v)
	}))
	// Someone else stored an encoder (or placeholder) first; use theirs.
	if loaded {
		return fi.(encoderFunc)
	}

	// Compute the real encoder and replace the indirect func with it.
	f = e.newTypeEncoder(t)
	// Release any callers blocked inside the indirect func.
	wg.Done()
	encoders.Store(entry, f)
	return f
}
78
79func marshalerEncoder(key string, value reflect.Value) ([]Pair, error) {
80 s, err := value.Interface().(json.Marshaler).MarshalJSON()
81 if err != nil {
82 return nil, fmt.Errorf("apiquery: json fallback marshal error %s", err)
83 }
84 return []Pair{{key, string(s)}}, nil
85}
86
87func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc {
88 if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
89 return e.newTimeTypeEncoder(t)
90 }
91
92 if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) {
93 return e.newRichFieldTypeEncoder(t)
94 }
95
96 if !e.root && t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) {
97 return marshalerEncoder
98 }
99
100 e.root = false
101 switch t.Kind() {
102 case reflect.Pointer:
103 encoder := e.typeEncoder(t.Elem())
104 return func(key string, value reflect.Value) (pairs []Pair, err error) {
105 if !value.IsValid() || value.IsNil() {
106 return
107 }
108 return encoder(key, value.Elem())
109 }
110 case reflect.Struct:
111 return e.newStructTypeEncoder(t)
112 case reflect.Array:
113 fallthrough
114 case reflect.Slice:
115 return e.newArrayTypeEncoder(t)
116 case reflect.Map:
117 return e.newMapEncoder(t)
118 case reflect.Interface:
119 return e.newInterfaceEncoder()
120 default:
121 return e.newPrimitiveTypeEncoder(t)
122 }
123}
124
125func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
126 if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) {
127 return e.newRichFieldTypeEncoder(t)
128 }
129
130 for i := 0; i < t.NumField(); i++ {
131 if t.Field(i).Type == paramUnionType && t.Field(i).Anonymous {
132 return e.newStructUnionTypeEncoder(t)
133 }
134 }
135
136 encoderFields := []encoderField{}
137
138 // This helper allows us to recursively collect field encoders into a flat
139 // array. The parameter `index` keeps track of the access patterns necessary
140 // to get to some field.
141 var collectEncoderFields func(r reflect.Type, index []int)
142 collectEncoderFields = func(r reflect.Type, index []int) {
143 for i := 0; i < r.NumField(); i++ {
144 idx := append(index, i)
145 field := t.FieldByIndex(idx)
146 if !field.IsExported() {
147 continue
148 }
149 // If this is an embedded struct, traverse one level deeper to extract
150 // the field and get their encoders as well.
151 if field.Anonymous {
152 collectEncoderFields(field.Type, idx)
153 continue
154 }
155 // If query tag is not present, then we skip, which is intentionally
156 // different behavior from the stdlib.
157 ptag, ok := parseQueryStructTag(field)
158 if !ok {
159 continue
160 }
161
162 if (ptag.name == "-" || ptag.name == "") && !ptag.inline {
163 continue
164 }
165
166 dateFormat, ok := parseFormatStructTag(field)
167 oldFormat := e.dateFormat
168 if ok {
169 switch dateFormat {
170 case "date-time":
171 e.dateFormat = time.RFC3339
172 case "date":
173 e.dateFormat = "2006-01-02"
174 }
175 }
176 var encoderFn encoderFunc
177 if ptag.omitzero {
178 typeEncoderFn := e.typeEncoder(field.Type)
179 encoderFn = func(key string, value reflect.Value) ([]Pair, error) {
180 if value.IsZero() {
181 return nil, nil
182 }
183 return typeEncoderFn(key, value)
184 }
185 } else {
186 encoderFn = e.typeEncoder(field.Type)
187 }
188 encoderFields = append(encoderFields, encoderField{ptag, encoderFn, idx})
189 e.dateFormat = oldFormat
190 }
191 }
192 collectEncoderFields(t, []int{})
193
194 return func(key string, value reflect.Value) (pairs []Pair, err error) {
195 for _, ef := range encoderFields {
196 var subkey string = e.renderKeyPath(key, ef.tag.name)
197 if ef.tag.inline {
198 subkey = key
199 }
200
201 field := value.FieldByIndex(ef.idx)
202 subpairs, suberr := ef.fn(subkey, field)
203 if suberr != nil {
204 err = suberr
205 }
206 pairs = append(pairs, subpairs...)
207 }
208 return
209 }
210}
211
212var paramUnionType = reflect.TypeOf((*param.APIUnion)(nil)).Elem()
213
214func (e *encoder) newStructUnionTypeEncoder(t reflect.Type) encoderFunc {
215 var fieldEncoders []encoderFunc
216 for i := 0; i < t.NumField(); i++ {
217 field := t.Field(i)
218 if field.Type == paramUnionType && field.Anonymous {
219 fieldEncoders = append(fieldEncoders, nil)
220 continue
221 }
222 fieldEncoders = append(fieldEncoders, e.typeEncoder(field.Type))
223 }
224
225 return func(key string, value reflect.Value) (pairs []Pair, err error) {
226 for i := 0; i < t.NumField(); i++ {
227 if value.Field(i).Type() == paramUnionType {
228 continue
229 }
230 if !value.Field(i).IsZero() {
231 return fieldEncoders[i](key, value.Field(i))
232 }
233 }
234 return nil, fmt.Errorf("apiquery: union %s has no field set", t.String())
235 }
236}
237
238func (e *encoder) newMapEncoder(t reflect.Type) encoderFunc {
239 keyEncoder := e.typeEncoder(t.Key())
240 elementEncoder := e.typeEncoder(t.Elem())
241 return func(key string, value reflect.Value) (pairs []Pair, err error) {
242 iter := value.MapRange()
243 for iter.Next() {
244 encodedKey, err := keyEncoder("", iter.Key())
245 if err != nil {
246 return nil, err
247 }
248 if len(encodedKey) != 1 {
249 return nil, fmt.Errorf("apiquery: unexpected number of parts for encoded map key, map may contain non-primitive")
250 }
251 subkey := encodedKey[0].value
252 keyPath := e.renderKeyPath(key, subkey)
253 subpairs, suberr := elementEncoder(keyPath, iter.Value())
254 if suberr != nil {
255 err = suberr
256 }
257 pairs = append(pairs, subpairs...)
258 }
259 return
260 }
261}
262
263func (e *encoder) renderKeyPath(key string, subkey string) string {
264 if len(key) == 0 {
265 return subkey
266 }
267 if e.settings.NestedFormat == NestedQueryFormatDots {
268 return fmt.Sprintf("%s.%s", key, subkey)
269 }
270 return fmt.Sprintf("%s[%s]", key, subkey)
271}
272
273func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc {
274 switch e.settings.ArrayFormat {
275 case ArrayQueryFormatComma:
276 innerEncoder := e.typeEncoder(t.Elem())
277 return func(key string, v reflect.Value) ([]Pair, error) {
278 elements := []string{}
279 for i := 0; i < v.Len(); i++ {
280 innerPairs, err := innerEncoder("", v.Index(i))
281 if err != nil {
282 return nil, err
283 }
284 for _, pair := range innerPairs {
285 elements = append(elements, pair.value)
286 }
287 }
288 if len(elements) == 0 {
289 return []Pair{}, nil
290 }
291 return []Pair{{key, strings.Join(elements, ",")}}, nil
292 }
293 case ArrayQueryFormatRepeat:
294 innerEncoder := e.typeEncoder(t.Elem())
295 return func(key string, value reflect.Value) (pairs []Pair, err error) {
296 for i := 0; i < value.Len(); i++ {
297 subpairs, suberr := innerEncoder(key, value.Index(i))
298 if suberr != nil {
299 err = suberr
300 }
301 pairs = append(pairs, subpairs...)
302 }
303 return
304 }
305 case ArrayQueryFormatIndices:
306 panic("The array indices format is not supported yet")
307 case ArrayQueryFormatBrackets:
308 innerEncoder := e.typeEncoder(t.Elem())
309 return func(key string, value reflect.Value) (pairs []Pair, err error) {
310 pairs = []Pair{}
311 for i := 0; i < value.Len(); i++ {
312 subpairs, suberr := innerEncoder(key+"[]", value.Index(i))
313 if suberr != nil {
314 err = suberr
315 }
316 pairs = append(pairs, subpairs...)
317 }
318 return
319 }
320 default:
321 panic(fmt.Sprintf("Unknown ArrayFormat value: %d", e.settings.ArrayFormat))
322 }
323}
324
325func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc {
326 switch t.Kind() {
327 case reflect.Pointer:
328 inner := t.Elem()
329
330 innerEncoder := e.newPrimitiveTypeEncoder(inner)
331 return func(key string, v reflect.Value) ([]Pair, error) {
332 if !v.IsValid() || v.IsNil() {
333 return nil, nil
334 }
335 return innerEncoder(key, v.Elem())
336 }
337 case reflect.String:
338 return func(key string, v reflect.Value) ([]Pair, error) {
339 return []Pair{{key, v.String()}}, nil
340 }
341 case reflect.Bool:
342 return func(key string, v reflect.Value) ([]Pair, error) {
343 if v.Bool() {
344 return []Pair{{key, "true"}}, nil
345 }
346 return []Pair{{key, "false"}}, nil
347 }
348 case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64:
349 return func(key string, v reflect.Value) ([]Pair, error) {
350 return []Pair{{key, strconv.FormatInt(v.Int(), 10)}}, nil
351 }
352 case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64:
353 return func(key string, v reflect.Value) ([]Pair, error) {
354 return []Pair{{key, strconv.FormatUint(v.Uint(), 10)}}, nil
355 }
356 case reflect.Float32, reflect.Float64:
357 return func(key string, v reflect.Value) ([]Pair, error) {
358 return []Pair{{key, strconv.FormatFloat(v.Float(), 'f', -1, 64)}}, nil
359 }
360 case reflect.Complex64, reflect.Complex128:
361 bitSize := 64
362 if t.Kind() == reflect.Complex128 {
363 bitSize = 128
364 }
365 return func(key string, v reflect.Value) ([]Pair, error) {
366 return []Pair{{key, strconv.FormatComplex(v.Complex(), 'f', -1, bitSize)}}, nil
367 }
368 default:
369 return func(key string, v reflect.Value) ([]Pair, error) {
370 return nil, nil
371 }
372 }
373}
374
375func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc {
376 f, _ := t.FieldByName("Value")
377 enc := e.typeEncoder(f.Type)
378
379 return func(key string, value reflect.Value) ([]Pair, error) {
380 present := value.FieldByName("Present")
381 if !present.Bool() {
382 return nil, nil
383 }
384 null := value.FieldByName("Null")
385 if null.Bool() {
386 return nil, fmt.Errorf("apiquery: field cannot be null")
387 }
388 raw := value.FieldByName("Raw")
389 if !raw.IsNil() {
390 return e.typeEncoder(raw.Type())(key, raw)
391 }
392 return enc(key, value.FieldByName("Value"))
393 }
394}
395
396func (e *encoder) newTimeTypeEncoder(_ reflect.Type) encoderFunc {
397 format := e.dateFormat
398 return func(key string, value reflect.Value) ([]Pair, error) {
399 return []Pair{{
400 key,
401 value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format),
402 }}, nil
403 }
404}
405
406func (e encoder) newInterfaceEncoder() encoderFunc {
407 return func(key string, value reflect.Value) ([]Pair, error) {
408 value = value.Elem()
409 if !value.IsValid() {
410 return nil, nil
411 }
412 return e.typeEncoder(value.Type())(key, value)
413 }
414
415}