encoder.go

  1package apiquery
  2
  3import (
  4	"encoding/json"
  5	"fmt"
  6	"reflect"
  7	"strconv"
  8	"strings"
  9	"sync"
 10	"time"
 11
 12	"github.com/anthropics/anthropic-sdk-go/packages/param"
 13)
 14
 15var encoders sync.Map // map[reflect.Type]encoderFunc
 16
 17type encoder struct {
 18	dateFormat string
 19	root       bool
 20	settings   QuerySettings
 21}
 22
 23type encoderFunc func(key string, value reflect.Value) ([]Pair, error)
 24
 25type encoderField struct {
 26	tag parsedStructTag
 27	fn  encoderFunc
 28	idx []int
 29}
 30
 31type encoderEntry struct {
 32	reflect.Type
 33	dateFormat string
 34	root       bool
 35	settings   QuerySettings
 36}
 37
 38type Pair struct {
 39	key   string
 40	value string
 41}
 42
 43func (e *encoder) typeEncoder(t reflect.Type) encoderFunc {
 44	entry := encoderEntry{
 45		Type:       t,
 46		dateFormat: e.dateFormat,
 47		root:       e.root,
 48		settings:   e.settings,
 49	}
 50
 51	if fi, ok := encoders.Load(entry); ok {
 52		return fi.(encoderFunc)
 53	}
 54
 55	// To deal with recursive types, populate the map with an
 56	// indirect func before we build it. This type waits on the
 57	// real func (f) to be ready and then calls it. This indirect
 58	// func is only used for recursive types.
 59	var (
 60		wg sync.WaitGroup
 61		f  encoderFunc
 62	)
 63	wg.Add(1)
 64	fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(key string, v reflect.Value) ([]Pair, error) {
 65		wg.Wait()
 66		return f(key, v)
 67	}))
 68	if loaded {
 69		return fi.(encoderFunc)
 70	}
 71
 72	// Compute the real encoder and replace the indirect func with it.
 73	f = e.newTypeEncoder(t)
 74	wg.Done()
 75	encoders.Store(entry, f)
 76	return f
 77}
 78
 79func marshalerEncoder(key string, value reflect.Value) ([]Pair, error) {
 80	s, err := value.Interface().(json.Marshaler).MarshalJSON()
 81	if err != nil {
 82		return nil, fmt.Errorf("apiquery: json fallback marshal error %s", err)
 83	}
 84	return []Pair{{key, string(s)}}, nil
 85}
 86
 87func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc {
 88	if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
 89		return e.newTimeTypeEncoder(t)
 90	}
 91
 92	if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) {
 93		return e.newRichFieldTypeEncoder(t)
 94	}
 95
 96	if !e.root && t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) {
 97		return marshalerEncoder
 98	}
 99
100	e.root = false
101	switch t.Kind() {
102	case reflect.Pointer:
103		encoder := e.typeEncoder(t.Elem())
104		return func(key string, value reflect.Value) (pairs []Pair, err error) {
105			if !value.IsValid() || value.IsNil() {
106				return
107			}
108			return encoder(key, value.Elem())
109		}
110	case reflect.Struct:
111		return e.newStructTypeEncoder(t)
112	case reflect.Array:
113		fallthrough
114	case reflect.Slice:
115		return e.newArrayTypeEncoder(t)
116	case reflect.Map:
117		return e.newMapEncoder(t)
118	case reflect.Interface:
119		return e.newInterfaceEncoder()
120	default:
121		return e.newPrimitiveTypeEncoder(t)
122	}
123}
124
125func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
126	if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) {
127		return e.newRichFieldTypeEncoder(t)
128	}
129
130	for i := 0; i < t.NumField(); i++ {
131		if t.Field(i).Type == paramUnionType && t.Field(i).Anonymous {
132			return e.newStructUnionTypeEncoder(t)
133		}
134	}
135
136	encoderFields := []encoderField{}
137
138	// This helper allows us to recursively collect field encoders into a flat
139	// array. The parameter `index` keeps track of the access patterns necessary
140	// to get to some field.
141	var collectEncoderFields func(r reflect.Type, index []int)
142	collectEncoderFields = func(r reflect.Type, index []int) {
143		for i := 0; i < r.NumField(); i++ {
144			idx := append(index, i)
145			field := t.FieldByIndex(idx)
146			if !field.IsExported() {
147				continue
148			}
149			// If this is an embedded struct, traverse one level deeper to extract
150			// the field and get their encoders as well.
151			if field.Anonymous {
152				collectEncoderFields(field.Type, idx)
153				continue
154			}
155			// If query tag is not present, then we skip, which is intentionally
156			// different behavior from the stdlib.
157			ptag, ok := parseQueryStructTag(field)
158			if !ok {
159				continue
160			}
161
162			if (ptag.name == "-" || ptag.name == "") && !ptag.inline {
163				continue
164			}
165
166			dateFormat, ok := parseFormatStructTag(field)
167			oldFormat := e.dateFormat
168			if ok {
169				switch dateFormat {
170				case "date-time":
171					e.dateFormat = time.RFC3339
172				case "date":
173					e.dateFormat = "2006-01-02"
174				}
175			}
176			var encoderFn encoderFunc
177			if ptag.omitzero {
178				typeEncoderFn := e.typeEncoder(field.Type)
179				encoderFn = func(key string, value reflect.Value) ([]Pair, error) {
180					if value.IsZero() {
181						return nil, nil
182					}
183					return typeEncoderFn(key, value)
184				}
185			} else {
186				encoderFn = e.typeEncoder(field.Type)
187			}
188			encoderFields = append(encoderFields, encoderField{ptag, encoderFn, idx})
189			e.dateFormat = oldFormat
190		}
191	}
192	collectEncoderFields(t, []int{})
193
194	return func(key string, value reflect.Value) (pairs []Pair, err error) {
195		for _, ef := range encoderFields {
196			var subkey string = e.renderKeyPath(key, ef.tag.name)
197			if ef.tag.inline {
198				subkey = key
199			}
200
201			field := value.FieldByIndex(ef.idx)
202			subpairs, suberr := ef.fn(subkey, field)
203			if suberr != nil {
204				err = suberr
205			}
206			pairs = append(pairs, subpairs...)
207		}
208		return
209	}
210}
211
212var paramUnionType = reflect.TypeOf((*param.APIUnion)(nil)).Elem()
213
214func (e *encoder) newStructUnionTypeEncoder(t reflect.Type) encoderFunc {
215	var fieldEncoders []encoderFunc
216	for i := 0; i < t.NumField(); i++ {
217		field := t.Field(i)
218		if field.Type == paramUnionType && field.Anonymous {
219			fieldEncoders = append(fieldEncoders, nil)
220			continue
221		}
222		fieldEncoders = append(fieldEncoders, e.typeEncoder(field.Type))
223	}
224
225	return func(key string, value reflect.Value) (pairs []Pair, err error) {
226		for i := 0; i < t.NumField(); i++ {
227			if value.Field(i).Type() == paramUnionType {
228				continue
229			}
230			if !value.Field(i).IsZero() {
231				return fieldEncoders[i](key, value.Field(i))
232			}
233		}
234		return nil, fmt.Errorf("apiquery: union %s has no field set", t.String())
235	}
236}
237
238func (e *encoder) newMapEncoder(t reflect.Type) encoderFunc {
239	keyEncoder := e.typeEncoder(t.Key())
240	elementEncoder := e.typeEncoder(t.Elem())
241	return func(key string, value reflect.Value) (pairs []Pair, err error) {
242		iter := value.MapRange()
243		for iter.Next() {
244			encodedKey, err := keyEncoder("", iter.Key())
245			if err != nil {
246				return nil, err
247			}
248			if len(encodedKey) != 1 {
249				return nil, fmt.Errorf("apiquery: unexpected number of parts for encoded map key, map may contain non-primitive")
250			}
251			subkey := encodedKey[0].value
252			keyPath := e.renderKeyPath(key, subkey)
253			subpairs, suberr := elementEncoder(keyPath, iter.Value())
254			if suberr != nil {
255				err = suberr
256			}
257			pairs = append(pairs, subpairs...)
258		}
259		return
260	}
261}
262
263func (e *encoder) renderKeyPath(key string, subkey string) string {
264	if len(key) == 0 {
265		return subkey
266	}
267	if e.settings.NestedFormat == NestedQueryFormatDots {
268		return fmt.Sprintf("%s.%s", key, subkey)
269	}
270	return fmt.Sprintf("%s[%s]", key, subkey)
271}
272
273func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc {
274	switch e.settings.ArrayFormat {
275	case ArrayQueryFormatComma:
276		innerEncoder := e.typeEncoder(t.Elem())
277		return func(key string, v reflect.Value) ([]Pair, error) {
278			elements := []string{}
279			for i := 0; i < v.Len(); i++ {
280				innerPairs, err := innerEncoder("", v.Index(i))
281				if err != nil {
282					return nil, err
283				}
284				for _, pair := range innerPairs {
285					elements = append(elements, pair.value)
286				}
287			}
288			if len(elements) == 0 {
289				return []Pair{}, nil
290			}
291			return []Pair{{key, strings.Join(elements, ",")}}, nil
292		}
293	case ArrayQueryFormatRepeat:
294		innerEncoder := e.typeEncoder(t.Elem())
295		return func(key string, value reflect.Value) (pairs []Pair, err error) {
296			for i := 0; i < value.Len(); i++ {
297				subpairs, suberr := innerEncoder(key, value.Index(i))
298				if suberr != nil {
299					err = suberr
300				}
301				pairs = append(pairs, subpairs...)
302			}
303			return
304		}
305	case ArrayQueryFormatIndices:
306		panic("The array indices format is not supported yet")
307	case ArrayQueryFormatBrackets:
308		innerEncoder := e.typeEncoder(t.Elem())
309		return func(key string, value reflect.Value) (pairs []Pair, err error) {
310			pairs = []Pair{}
311			for i := 0; i < value.Len(); i++ {
312				subpairs, suberr := innerEncoder(key+"[]", value.Index(i))
313				if suberr != nil {
314					err = suberr
315				}
316				pairs = append(pairs, subpairs...)
317			}
318			return
319		}
320	default:
321		panic(fmt.Sprintf("Unknown ArrayFormat value: %d", e.settings.ArrayFormat))
322	}
323}
324
325func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc {
326	switch t.Kind() {
327	case reflect.Pointer:
328		inner := t.Elem()
329
330		innerEncoder := e.newPrimitiveTypeEncoder(inner)
331		return func(key string, v reflect.Value) ([]Pair, error) {
332			if !v.IsValid() || v.IsNil() {
333				return nil, nil
334			}
335			return innerEncoder(key, v.Elem())
336		}
337	case reflect.String:
338		return func(key string, v reflect.Value) ([]Pair, error) {
339			return []Pair{{key, v.String()}}, nil
340		}
341	case reflect.Bool:
342		return func(key string, v reflect.Value) ([]Pair, error) {
343			if v.Bool() {
344				return []Pair{{key, "true"}}, nil
345			}
346			return []Pair{{key, "false"}}, nil
347		}
348	case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64:
349		return func(key string, v reflect.Value) ([]Pair, error) {
350			return []Pair{{key, strconv.FormatInt(v.Int(), 10)}}, nil
351		}
352	case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64:
353		return func(key string, v reflect.Value) ([]Pair, error) {
354			return []Pair{{key, strconv.FormatUint(v.Uint(), 10)}}, nil
355		}
356	case reflect.Float32, reflect.Float64:
357		return func(key string, v reflect.Value) ([]Pair, error) {
358			return []Pair{{key, strconv.FormatFloat(v.Float(), 'f', -1, 64)}}, nil
359		}
360	case reflect.Complex64, reflect.Complex128:
361		bitSize := 64
362		if t.Kind() == reflect.Complex128 {
363			bitSize = 128
364		}
365		return func(key string, v reflect.Value) ([]Pair, error) {
366			return []Pair{{key, strconv.FormatComplex(v.Complex(), 'f', -1, bitSize)}}, nil
367		}
368	default:
369		return func(key string, v reflect.Value) ([]Pair, error) {
370			return nil, nil
371		}
372	}
373}
374
375func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc {
376	f, _ := t.FieldByName("Value")
377	enc := e.typeEncoder(f.Type)
378
379	return func(key string, value reflect.Value) ([]Pair, error) {
380		present := value.FieldByName("Present")
381		if !present.Bool() {
382			return nil, nil
383		}
384		null := value.FieldByName("Null")
385		if null.Bool() {
386			return nil, fmt.Errorf("apiquery: field cannot be null")
387		}
388		raw := value.FieldByName("Raw")
389		if !raw.IsNil() {
390			return e.typeEncoder(raw.Type())(key, raw)
391		}
392		return enc(key, value.FieldByName("Value"))
393	}
394}
395
396func (e *encoder) newTimeTypeEncoder(_ reflect.Type) encoderFunc {
397	format := e.dateFormat
398	return func(key string, value reflect.Value) ([]Pair, error) {
399		return []Pair{{
400			key,
401			value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format),
402		}}, nil
403	}
404}
405
406func (e encoder) newInterfaceEncoder() encoderFunc {
407	return func(key string, value reflect.Value) ([]Pair, error) {
408		value = value.Elem()
409		if !value.IsValid() {
410			return nil, nil
411		}
412		return e.typeEncoder(value.Type())(key, value)
413	}
414
415}