encoder.go

  1package apijson
  2
import (
	"bytes"
	"encoding/json"
	"fmt"
	"math"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/tidwall/sjson"

	"github.com/openai/openai-go/internal/param"
)
 18
 19var encoders sync.Map // map[encoderEntry]encoderFunc
 20
 21func Marshal(value interface{}) ([]byte, error) {
 22	e := &encoder{dateFormat: time.RFC3339}
 23	return e.marshal(value)
 24}
 25
 26func MarshalRoot(value interface{}) ([]byte, error) {
 27	e := &encoder{root: true, dateFormat: time.RFC3339}
 28	return e.marshal(value)
 29}
 30
 31type encoder struct {
 32	dateFormat string
 33	root       bool
 34}
 35
 36type encoderFunc func(value reflect.Value) ([]byte, error)
 37
 38type encoderField struct {
 39	tag parsedStructTag
 40	fn  encoderFunc
 41	idx []int
 42}
 43
 44type encoderEntry struct {
 45	reflect.Type
 46	dateFormat string
 47	root       bool
 48}
 49
 50func (e *encoder) marshal(value interface{}) ([]byte, error) {
 51	val := reflect.ValueOf(value)
 52	if !val.IsValid() {
 53		return nil, nil
 54	}
 55	typ := val.Type()
 56	enc := e.typeEncoder(typ)
 57	return enc(val)
 58}
 59
 60func (e *encoder) typeEncoder(t reflect.Type) encoderFunc {
 61	entry := encoderEntry{
 62		Type:       t,
 63		dateFormat: e.dateFormat,
 64		root:       e.root,
 65	}
 66
 67	if fi, ok := encoders.Load(entry); ok {
 68		return fi.(encoderFunc)
 69	}
 70
 71	// To deal with recursive types, populate the map with an
 72	// indirect func before we build it. This type waits on the
 73	// real func (f) to be ready and then calls it. This indirect
 74	// func is only used for recursive types.
 75	var (
 76		wg sync.WaitGroup
 77		f  encoderFunc
 78	)
 79	wg.Add(1)
 80	fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(v reflect.Value) ([]byte, error) {
 81		wg.Wait()
 82		return f(v)
 83	}))
 84	if loaded {
 85		return fi.(encoderFunc)
 86	}
 87
 88	// Compute the real encoder and replace the indirect func with it.
 89	f = e.newTypeEncoder(t)
 90	wg.Done()
 91	encoders.Store(entry, f)
 92	return f
 93}
 94
 95func marshalerEncoder(v reflect.Value) ([]byte, error) {
 96	return v.Interface().(json.Marshaler).MarshalJSON()
 97}
 98
 99func indirectMarshalerEncoder(v reflect.Value) ([]byte, error) {
100	return v.Addr().Interface().(json.Marshaler).MarshalJSON()
101}
102
103func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc {
104	if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
105		return e.newTimeTypeEncoder()
106	}
107	if !e.root && t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) {
108		return marshalerEncoder
109	}
110	if !e.root && reflect.PointerTo(t).Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) {
111		return indirectMarshalerEncoder
112	}
113	e.root = false
114	switch t.Kind() {
115	case reflect.Pointer:
116		inner := t.Elem()
117
118		innerEncoder := e.typeEncoder(inner)
119		return func(v reflect.Value) ([]byte, error) {
120			if !v.IsValid() || v.IsNil() {
121				return nil, nil
122			}
123			return innerEncoder(v.Elem())
124		}
125	case reflect.Struct:
126		return e.newStructTypeEncoder(t)
127	case reflect.Array:
128		fallthrough
129	case reflect.Slice:
130		return e.newArrayTypeEncoder(t)
131	case reflect.Map:
132		return e.newMapEncoder(t)
133	case reflect.Interface:
134		return e.newInterfaceEncoder()
135	default:
136		return e.newPrimitiveTypeEncoder(t)
137	}
138}
139
140func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc {
141	switch t.Kind() {
142	// Note that we could use `gjson` to encode these types but it would complicate our
143	// code more and this current code shouldn't cause any issues
144	case reflect.String:
145		return func(v reflect.Value) ([]byte, error) {
146			return json.Marshal(v.Interface())
147		}
148	case reflect.Bool:
149		return func(v reflect.Value) ([]byte, error) {
150			if v.Bool() {
151				return []byte("true"), nil
152			}
153			return []byte("false"), nil
154		}
155	case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64:
156		return func(v reflect.Value) ([]byte, error) {
157			return []byte(strconv.FormatInt(v.Int(), 10)), nil
158		}
159	case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64:
160		return func(v reflect.Value) ([]byte, error) {
161			return []byte(strconv.FormatUint(v.Uint(), 10)), nil
162		}
163	case reflect.Float32:
164		return func(v reflect.Value) ([]byte, error) {
165			return []byte(strconv.FormatFloat(v.Float(), 'f', -1, 32)), nil
166		}
167	case reflect.Float64:
168		return func(v reflect.Value) ([]byte, error) {
169			return []byte(strconv.FormatFloat(v.Float(), 'f', -1, 64)), nil
170		}
171	default:
172		return func(v reflect.Value) ([]byte, error) {
173			return nil, fmt.Errorf("unknown type received at primitive encoder: %s", t.String())
174		}
175	}
176}
177
178func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc {
179	itemEncoder := e.typeEncoder(t.Elem())
180
181	return func(value reflect.Value) ([]byte, error) {
182		json := []byte("[]")
183		for i := 0; i < value.Len(); i++ {
184			var value, err = itemEncoder(value.Index(i))
185			if err != nil {
186				return nil, err
187			}
188			if value == nil {
189				// Assume that empty items should be inserted as `null` so that the output array
190				// will be the same length as the input array
191				value = []byte("null")
192			}
193
194			json, err = sjson.SetRawBytes(json, "-1", value)
195			if err != nil {
196				return nil, err
197			}
198		}
199
200		return json, nil
201	}
202}
203
204func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
205	if t.Implements(reflect.TypeOf((*param.FieldLike)(nil)).Elem()) {
206		return e.newFieldTypeEncoder(t)
207	}
208
209	encoderFields := []encoderField{}
210	extraEncoder := (*encoderField)(nil)
211
212	// This helper allows us to recursively collect field encoders into a flat
213	// array. The parameter `index` keeps track of the access patterns necessary
214	// to get to some field.
215	var collectEncoderFields func(r reflect.Type, index []int)
216	collectEncoderFields = func(r reflect.Type, index []int) {
217		for i := 0; i < r.NumField(); i++ {
218			idx := append(index, i)
219			field := t.FieldByIndex(idx)
220			if !field.IsExported() {
221				continue
222			}
223			// If this is an embedded struct, traverse one level deeper to extract
224			// the field and get their encoders as well.
225			if field.Anonymous {
226				collectEncoderFields(field.Type, idx)
227				continue
228			}
229			// If json tag is not present, then we skip, which is intentionally
230			// different behavior from the stdlib.
231			ptag, ok := parseJSONStructTag(field)
232			if !ok {
233				continue
234			}
235			// We only want to support unexported field if they're tagged with
236			// `extras` because that field shouldn't be part of the public API. We
237			// also want to only keep the top level extras
238			if ptag.extras && len(index) == 0 {
239				extraEncoder = &encoderField{ptag, e.typeEncoder(field.Type.Elem()), idx}
240				continue
241			}
242			if ptag.name == "-" {
243				continue
244			}
245
246			dateFormat, ok := parseFormatStructTag(field)
247			oldFormat := e.dateFormat
248			if ok {
249				switch dateFormat {
250				case "date-time":
251					e.dateFormat = time.RFC3339
252				case "date":
253					e.dateFormat = "2006-01-02"
254				}
255			}
256			encoderFields = append(encoderFields, encoderField{ptag, e.typeEncoder(field.Type), idx})
257			e.dateFormat = oldFormat
258		}
259	}
260	collectEncoderFields(t, []int{})
261
262	// Ensure deterministic output by sorting by lexicographic order
263	sort.Slice(encoderFields, func(i, j int) bool {
264		return encoderFields[i].tag.name < encoderFields[j].tag.name
265	})
266
267	return func(value reflect.Value) (json []byte, err error) {
268		json = []byte("{}")
269
270		for _, ef := range encoderFields {
271			field := value.FieldByIndex(ef.idx)
272			encoded, err := ef.fn(field)
273			if err != nil {
274				return nil, err
275			}
276			if encoded == nil {
277				continue
278			}
279			json, err = sjson.SetRawBytes(json, ef.tag.name, encoded)
280			if err != nil {
281				return nil, err
282			}
283		}
284
285		if extraEncoder != nil {
286			json, err = e.encodeMapEntries(json, value.FieldByIndex(extraEncoder.idx))
287			if err != nil {
288				return nil, err
289			}
290		}
291		return
292	}
293}
294
295func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc {
296	f, _ := t.FieldByName("Value")
297	enc := e.typeEncoder(f.Type)
298
299	return func(value reflect.Value) (json []byte, err error) {
300		present := value.FieldByName("Present")
301		if !present.Bool() {
302			return nil, nil
303		}
304		null := value.FieldByName("Null")
305		if null.Bool() {
306			return []byte("null"), nil
307		}
308		raw := value.FieldByName("Raw")
309		if !raw.IsNil() {
310			return e.typeEncoder(raw.Type())(raw)
311		}
312		return enc(value.FieldByName("Value"))
313	}
314}
315
316func (e *encoder) newTimeTypeEncoder() encoderFunc {
317	format := e.dateFormat
318	return func(value reflect.Value) (json []byte, err error) {
319		return []byte(`"` + value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format) + `"`), nil
320	}
321}
322
323func (e encoder) newInterfaceEncoder() encoderFunc {
324	return func(value reflect.Value) ([]byte, error) {
325		value = value.Elem()
326		if !value.IsValid() {
327			return nil, nil
328		}
329		return e.typeEncoder(value.Type())(value)
330	}
331}
332
333// Given a []byte of json (may either be an empty object or an object that already contains entries)
334// encode all of the entries in the map to the json byte array.
335func (e *encoder) encodeMapEntries(json []byte, v reflect.Value) ([]byte, error) {
336	type mapPair struct {
337		key   []byte
338		value reflect.Value
339	}
340
341	pairs := []mapPair{}
342	keyEncoder := e.typeEncoder(v.Type().Key())
343
344	iter := v.MapRange()
345	for iter.Next() {
346		var encodedKeyString string
347		if iter.Key().Type().Kind() == reflect.String {
348			encodedKeyString = iter.Key().String()
349		} else {
350			var err error
351			encodedKeyBytes, err := keyEncoder(iter.Key())
352			if err != nil {
353				return nil, err
354			}
355			encodedKeyString = string(encodedKeyBytes)
356		}
357		encodedKey := []byte(sjsonReplacer.Replace(encodedKeyString))
358		pairs = append(pairs, mapPair{key: encodedKey, value: iter.Value()})
359	}
360
361	// Ensure deterministic output
362	sort.Slice(pairs, func(i, j int) bool {
363		return bytes.Compare(pairs[i].key, pairs[j].key) < 0
364	})
365
366	elementEncoder := e.typeEncoder(v.Type().Elem())
367	for _, p := range pairs {
368		encodedValue, err := elementEncoder(p.value)
369		if err != nil {
370			return nil, err
371		}
372		if len(encodedValue) == 0 {
373			continue
374		}
375		json, err = sjson.SetRawBytes(json, string(p.key), encodedValue)
376		if err != nil {
377			return nil, err
378		}
379	}
380
381	return json, nil
382}
383
384func (e *encoder) newMapEncoder(_ reflect.Type) encoderFunc {
385	return func(value reflect.Value) ([]byte, error) {
386		json := []byte("{}")
387		var err error
388		json, err = e.encodeMapEntries(json, value)
389		if err != nil {
390			return nil, err
391		}
392		return json, nil
393	}
394}
395
// sjsonReplacer escapes a literal object key (such as a map key) so it can be
// passed to sjson as a single path component instead of being parsed as path
// syntax:
//   - "." separates path components;
//   - ":" forces a numeric-looking component to be an object key;
//   - "*" and "?" are wildcard characters;
//   - "\" starts an escape sequence, so it must itself be escaped.
//
// Without that last rule, a key like `a\b` would be read as the escape `\b`.
// Replacement is a single pass, so the backslashes added for the other
// characters are not escaped again.
var sjsonReplacer *strings.Replacer = strings.NewReplacer("\\", "\\\\", ".", "\\.", ":", "\\:", "*", "\\*", "?", "\\?")