encoder.go

  1package apiquery
  2
import (
	"encoding/json"
	"fmt"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	internalparam "github.com/openai/openai-go/internal/param"
	"github.com/openai/openai-go/packages/param"
)
 15
 16var encoders sync.Map // map[reflect.Type]encoderFunc
 17
 18type encoder struct {
 19	dateFormat string
 20	root       bool
 21	settings   QuerySettings
 22}
 23
 24type encoderFunc func(key string, value reflect.Value) []Pair
 25
 26type encoderField struct {
 27	tag parsedStructTag
 28	fn  encoderFunc
 29	idx []int
 30}
 31
 32type encoderEntry struct {
 33	reflect.Type
 34	dateFormat string
 35	root       bool
 36	settings   QuerySettings
 37}
 38
 39type Pair struct {
 40	key   string
 41	value string
 42}
 43
 44func (e *encoder) typeEncoder(t reflect.Type) encoderFunc {
 45	entry := encoderEntry{
 46		Type:       t,
 47		dateFormat: e.dateFormat,
 48		root:       e.root,
 49		settings:   e.settings,
 50	}
 51
 52	if fi, ok := encoders.Load(entry); ok {
 53		return fi.(encoderFunc)
 54	}
 55
 56	// To deal with recursive types, populate the map with an
 57	// indirect func before we build it. This type waits on the
 58	// real func (f) to be ready and then calls it. This indirect
 59	// func is only used for recursive types.
 60	var (
 61		wg sync.WaitGroup
 62		f  encoderFunc
 63	)
 64	wg.Add(1)
 65	fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(key string, v reflect.Value) []Pair {
 66		wg.Wait()
 67		return f(key, v)
 68	}))
 69	if loaded {
 70		return fi.(encoderFunc)
 71	}
 72
 73	// Compute the real encoder and replace the indirect func with it.
 74	f = e.newTypeEncoder(t)
 75	wg.Done()
 76	encoders.Store(entry, f)
 77	return f
 78}
 79
 80func marshalerEncoder(key string, value reflect.Value) []Pair {
 81	s, _ := value.Interface().(json.Marshaler).MarshalJSON()
 82	return []Pair{{key, string(s)}}
 83}
 84
 85func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc {
 86	if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
 87		return e.newTimeTypeEncoder(t)
 88	}
 89
 90	if !e.root && t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) && param.OptionalPrimitiveTypes[t] == nil {
 91		return marshalerEncoder
 92	}
 93	e.root = false
 94	switch t.Kind() {
 95	case reflect.Pointer:
 96		encoder := e.typeEncoder(t.Elem())
 97		return func(key string, value reflect.Value) (pairs []Pair) {
 98			if !value.IsValid() || value.IsNil() {
 99				return
100			}
101			pairs = encoder(key, value.Elem())
102			return
103		}
104	case reflect.Struct:
105		return e.newStructTypeEncoder(t)
106	case reflect.Array:
107		fallthrough
108	case reflect.Slice:
109		return e.newArrayTypeEncoder(t)
110	case reflect.Map:
111		return e.newMapEncoder(t)
112	case reflect.Interface:
113		return e.newInterfaceEncoder()
114	default:
115		return e.newPrimitiveTypeEncoder(t)
116	}
117}
118
119func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
120	if t.Implements(reflect.TypeOf((*internalparam.FieldLike)(nil)).Elem()) {
121		return e.newFieldTypeEncoder(t)
122	}
123
124	if idx, ok := param.OptionalPrimitiveTypes[t]; ok {
125		return e.newRichFieldTypeEncoder(t, idx)
126	}
127
128	encoderFields := []encoderField{}
129
130	// This helper allows us to recursively collect field encoders into a flat
131	// array. The parameter `index` keeps track of the access patterns necessary
132	// to get to some field.
133	var collectEncoderFields func(r reflect.Type, index []int)
134	collectEncoderFields = func(r reflect.Type, index []int) {
135		for i := 0; i < r.NumField(); i++ {
136			idx := append(index, i)
137			field := t.FieldByIndex(idx)
138			if !field.IsExported() {
139				continue
140			}
141			// If this is an embedded struct, traverse one level deeper to extract
142			// the field and get their encoders as well.
143			if field.Anonymous {
144				collectEncoderFields(field.Type, idx)
145				continue
146			}
147			// If query tag is not present, then we skip, which is intentionally
148			// different behavior from the stdlib.
149			ptag, ok := parseQueryStructTag(field)
150			if !ok {
151				continue
152			}
153
154			if (ptag.name == "-" || ptag.name == "") && !ptag.inline {
155				continue
156			}
157
158			dateFormat, ok := parseFormatStructTag(field)
159			oldFormat := e.dateFormat
160			if ok {
161				switch dateFormat {
162				case "date-time":
163					e.dateFormat = time.RFC3339
164				case "date":
165					e.dateFormat = "2006-01-02"
166				}
167			}
168			var encoderFn encoderFunc
169			if ptag.omitzero {
170				typeEncoderFn := e.typeEncoder(field.Type)
171				encoderFn = func(key string, value reflect.Value) []Pair {
172					if value.IsZero() {
173						return nil
174					}
175					return typeEncoderFn(key, value)
176				}
177			} else {
178				encoderFn = e.typeEncoder(field.Type)
179			}
180			encoderFields = append(encoderFields, encoderField{ptag, encoderFn, idx})
181			e.dateFormat = oldFormat
182		}
183	}
184	collectEncoderFields(t, []int{})
185
186	return func(key string, value reflect.Value) (pairs []Pair) {
187		for _, ef := range encoderFields {
188			var subkey string = e.renderKeyPath(key, ef.tag.name)
189			if ef.tag.inline {
190				subkey = key
191			}
192
193			field := value.FieldByIndex(ef.idx)
194			pairs = append(pairs, ef.fn(subkey, field)...)
195		}
196		return
197	}
198}
199
200func (e *encoder) newMapEncoder(t reflect.Type) encoderFunc {
201	keyEncoder := e.typeEncoder(t.Key())
202	elementEncoder := e.typeEncoder(t.Elem())
203	return func(key string, value reflect.Value) (pairs []Pair) {
204		iter := value.MapRange()
205		for iter.Next() {
206			encodedKey := keyEncoder("", iter.Key())
207			if len(encodedKey) != 1 {
208				panic("Unexpected number of parts for encoded map key. Are you using a non-primitive for this map?")
209			}
210			subkey := encodedKey[0].value
211			keyPath := e.renderKeyPath(key, subkey)
212			pairs = append(pairs, elementEncoder(keyPath, iter.Value())...)
213		}
214		return
215	}
216}
217
218func (e *encoder) renderKeyPath(key string, subkey string) string {
219	if len(key) == 0 {
220		return subkey
221	}
222	if e.settings.NestedFormat == NestedQueryFormatDots {
223		return fmt.Sprintf("%s.%s", key, subkey)
224	}
225	return fmt.Sprintf("%s[%s]", key, subkey)
226}
227
228func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc {
229	switch e.settings.ArrayFormat {
230	case ArrayQueryFormatComma:
231		innerEncoder := e.typeEncoder(t.Elem())
232		return func(key string, v reflect.Value) []Pair {
233			elements := []string{}
234			for i := 0; i < v.Len(); i++ {
235				for _, pair := range innerEncoder("", v.Index(i)) {
236					elements = append(elements, pair.value)
237				}
238			}
239			if len(elements) == 0 {
240				return []Pair{}
241			}
242			return []Pair{{key, strings.Join(elements, ",")}}
243		}
244	case ArrayQueryFormatRepeat:
245		innerEncoder := e.typeEncoder(t.Elem())
246		return func(key string, value reflect.Value) (pairs []Pair) {
247			for i := 0; i < value.Len(); i++ {
248				pairs = append(pairs, innerEncoder(key, value.Index(i))...)
249			}
250			return pairs
251		}
252	case ArrayQueryFormatIndices:
253		panic("The array indices format is not supported yet")
254	case ArrayQueryFormatBrackets:
255		innerEncoder := e.typeEncoder(t.Elem())
256		return func(key string, value reflect.Value) []Pair {
257			pairs := []Pair{}
258			for i := 0; i < value.Len(); i++ {
259				pairs = append(pairs, innerEncoder(key+"[]", value.Index(i))...)
260			}
261			return pairs
262		}
263	default:
264		panic(fmt.Sprintf("Unknown ArrayFormat value: %d", e.settings.ArrayFormat))
265	}
266}
267
268func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc {
269	switch t.Kind() {
270	case reflect.Pointer:
271		inner := t.Elem()
272
273		innerEncoder := e.newPrimitiveTypeEncoder(inner)
274		return func(key string, v reflect.Value) []Pair {
275			if !v.IsValid() || v.IsNil() {
276				return nil
277			}
278			return innerEncoder(key, v.Elem())
279		}
280	case reflect.String:
281		return func(key string, v reflect.Value) []Pair {
282			return []Pair{{key, v.String()}}
283		}
284	case reflect.Bool:
285		return func(key string, v reflect.Value) []Pair {
286			if v.Bool() {
287				return []Pair{{key, "true"}}
288			}
289			return []Pair{{key, "false"}}
290		}
291	case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64:
292		return func(key string, v reflect.Value) []Pair {
293			return []Pair{{key, strconv.FormatInt(v.Int(), 10)}}
294		}
295	case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64:
296		return func(key string, v reflect.Value) []Pair {
297			return []Pair{{key, strconv.FormatUint(v.Uint(), 10)}}
298		}
299	case reflect.Float32, reflect.Float64:
300		return func(key string, v reflect.Value) []Pair {
301			return []Pair{{key, strconv.FormatFloat(v.Float(), 'f', -1, 64)}}
302		}
303	case reflect.Complex64, reflect.Complex128:
304		bitSize := 64
305		if t.Kind() == reflect.Complex128 {
306			bitSize = 128
307		}
308		return func(key string, v reflect.Value) []Pair {
309			return []Pair{{key, strconv.FormatComplex(v.Complex(), 'f', -1, bitSize)}}
310		}
311	default:
312		return func(key string, v reflect.Value) []Pair {
313			return nil
314		}
315	}
316}
317
318func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc {
319	f, _ := t.FieldByName("Value")
320	enc := e.typeEncoder(f.Type)
321
322	return func(key string, value reflect.Value) []Pair {
323		present := value.FieldByName("Present")
324		if !present.Bool() {
325			return nil
326		}
327		null := value.FieldByName("Null")
328		if null.Bool() {
329			// TODO: Error?
330			return nil
331		}
332		raw := value.FieldByName("Raw")
333		if !raw.IsNil() {
334			return e.typeEncoder(raw.Type())(key, raw)
335		}
336		return enc(key, value.FieldByName("Value"))
337	}
338}
339
340func (e *encoder) newTimeTypeEncoder(t reflect.Type) encoderFunc {
341	format := e.dateFormat
342	return func(key string, value reflect.Value) []Pair {
343		return []Pair{{
344			key,
345			value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format),
346		}}
347	}
348}
349
350func (e encoder) newInterfaceEncoder() encoderFunc {
351	return func(key string, value reflect.Value) []Pair {
352		value = value.Elem()
353		if !value.IsValid() {
354			return nil
355		}
356		return e.typeEncoder(value.Type())(key, value)
357	}
358
359}