1package apijson
  2
import (
	"bytes"
	"encoding/json"
	"fmt"
	"math"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/tidwall/sjson"
)
 16
// encoders caches the encoderFunc built for each encoderEntry (a type plus the
// dateFormat and root settings it was built under). It is package-global, so
// encoders are shared across all Marshal/MarshalRoot calls. While an encoder is
// being built, its entry temporarily holds a placeholder func that blocks until
// the real encoder is ready (see typeEncoder).
var encoders sync.Map // map[encoderEntry]encoderFunc
 18
 19func Marshal(value any) ([]byte, error) {
 20	e := &encoder{dateFormat: time.RFC3339}
 21	return e.marshal(value)
 22}
 23
 24func MarshalRoot(value any) ([]byte, error) {
 25	e := &encoder{root: true, dateFormat: time.RFC3339}
 26	return e.marshal(value)
 27}
 28
// encoder carries the settings that shape how encoders are built. One *encoder
// is threaded through typeEncoder while encoders are constructed, and it is
// modified along the way: newTypeEncoder clears root, and newStructTypeEncoder
// temporarily swaps dateFormat for fields with a format annotation.
//
// dateFormat is the time layout used for values convertible to time.Time.
// Marshal and MarshalRoot start it at time.RFC3339.
//
// root is set by MarshalRoot and stays true only until newTypeEncoder has
// handled the outermost type. While it is true, json.Marshaler implementations
// are ignored.
type encoder struct {
	dateFormat string
	root       bool
}

// encoderFunc encodes a single value. Returning (nil, nil) means "nothing to
// write": struct fields and map entries are then omitted, and array elements
// are written as null.
type encoderFunc func(value reflect.Value) ([]byte, error)

// encoderField is one collected struct field: its parsed JSON tag, the encoder
// for its type, and its index path from the outer struct, as accepted by
// reflect.Value.FieldByIndex.
type encoderField struct {
	tag parsedStructTag
	fn  encoderFunc
	idx []int
}

// encoderEntry is the cache key for encoders. It holds a type plus the encoder
// settings that influence the encoder built for it, so one type can have
// several cached encoders (e.g. one per date format).
type encoderEntry struct {
	reflect.Type
	dateFormat string
	root       bool
}
 47
 48func (e *encoder) marshal(value any) ([]byte, error) {
 49	val := reflect.ValueOf(value)
 50	if !val.IsValid() {
 51		return nil, nil
 52	}
 53	typ := val.Type()
 54	enc := e.typeEncoder(typ)
 55	return enc(val)
 56}
 57
// typeEncoder returns the encoderFunc for t under e's current dateFormat and
// root settings. It builds the encoder on first use and caches it in encoders.
//
// Building may recurse back into typeEncoder (for fields, elements and
// pointees). newTypeEncoder may also modify e while building: it clears
// e.root, and struct encoding temporarily swaps e.dateFormat.
func (e *encoder) typeEncoder(t reflect.Type) encoderFunc {
	// The key includes every setting that changes the generated encoder.
	entry := encoderEntry{
		Type:       t,
		dateFormat: e.dateFormat,
		root:       e.root,
	}

	// Fast path: a finished encoder, or a placeholder (see below), is cached.
	if fi, ok := encoders.Load(entry); ok {
		return fi.(encoderFunc)
	}

	// To deal with recursive types, populate the map with an
	// indirect func before we build it. This func waits on the
	// real func (f) to be ready and then calls it. Besides recursive
	// types, a concurrent caller that loses the LoadOrStore race below
	// also receives this placeholder and blocks in wg.Wait until f is set.
	var (
		wg sync.WaitGroup
		f  encoderFunc
	)
	wg.Add(1)
	fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(v reflect.Value) ([]byte, error) {
		wg.Wait()
		return f(v)
	}))
	// Another goroutine registered this entry first; use whatever it stored.
	if loaded {
		return fi.(encoderFunc)
	}

	// Compute the real encoder and replace the indirect func with it.
	// Encoders built during construction that captured the placeholder keep
	// calling through it; wg.Done releases them.
	// NOTE(review): if newTypeEncoder panics, wg.Done never runs and anyone
	// calling the stored placeholder blocks forever — confirm that is acceptable.
	f = e.newTypeEncoder(t)
	wg.Done()
	encoders.Store(entry, f)
	return f
}
 92
 93func marshalerEncoder(v reflect.Value) ([]byte, error) {
 94	return v.Interface().(json.Marshaler).MarshalJSON()
 95}
 96
 97func indirectMarshalerEncoder(v reflect.Value) ([]byte, error) {
 98	return v.Addr().Interface().(json.Marshaler).MarshalJSON()
 99}
100
101func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc {
102	if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
103		return e.newTimeTypeEncoder()
104	}
105	if !e.root && t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) {
106		return marshalerEncoder
107	}
108	if !e.root && reflect.PointerTo(t).Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) {
109		return indirectMarshalerEncoder
110	}
111	e.root = false
112	switch t.Kind() {
113	case reflect.Pointer:
114		inner := t.Elem()
115
116		innerEncoder := e.typeEncoder(inner)
117		return func(v reflect.Value) ([]byte, error) {
118			if !v.IsValid() || v.IsNil() {
119				return nil, nil
120			}
121			return innerEncoder(v.Elem())
122		}
123	case reflect.Struct:
124		return e.newStructTypeEncoder(t)
125	case reflect.Array:
126		fallthrough
127	case reflect.Slice:
128		return e.newArrayTypeEncoder(t)
129	case reflect.Map:
130		return e.newMapEncoder(t)
131	case reflect.Interface:
132		return e.newInterfaceEncoder()
133	default:
134		return e.newPrimitiveTypeEncoder(t)
135	}
136}
137
138func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc {
139	switch t.Kind() {
140	// Note that we could use `gjson` to encode these types but it would complicate our
141	// code more and this current code shouldn't cause any issues
142	case reflect.String:
143		return func(v reflect.Value) ([]byte, error) {
144			return json.Marshal(v.Interface())
145		}
146	case reflect.Bool:
147		return func(v reflect.Value) ([]byte, error) {
148			if v.Bool() {
149				return []byte("true"), nil
150			}
151			return []byte("false"), nil
152		}
153	case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64:
154		return func(v reflect.Value) ([]byte, error) {
155			return []byte(strconv.FormatInt(v.Int(), 10)), nil
156		}
157	case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64:
158		return func(v reflect.Value) ([]byte, error) {
159			return []byte(strconv.FormatUint(v.Uint(), 10)), nil
160		}
161	case reflect.Float32:
162		return func(v reflect.Value) ([]byte, error) {
163			return []byte(strconv.FormatFloat(v.Float(), 'f', -1, 32)), nil
164		}
165	case reflect.Float64:
166		return func(v reflect.Value) ([]byte, error) {
167			return []byte(strconv.FormatFloat(v.Float(), 'f', -1, 64)), nil
168		}
169	default:
170		return func(v reflect.Value) ([]byte, error) {
171			return nil, fmt.Errorf("unknown type received at primitive encoder: %s", t.String())
172		}
173	}
174}
175
176func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc {
177	itemEncoder := e.typeEncoder(t.Elem())
178
179	return func(value reflect.Value) ([]byte, error) {
180		json := []byte("[]")
181		for i := 0; i < value.Len(); i++ {
182			var value, err = itemEncoder(value.Index(i))
183			if err != nil {
184				return nil, err
185			}
186			if value == nil {
187				// Assume that empty items should be inserted as `null` so that the output array
188				// will be the same length as the input array
189				value = []byte("null")
190			}
191
192			json, err = sjson.SetRawBytes(json, "-1", value)
193			if err != nil {
194				return nil, err
195			}
196		}
197
198		return json, nil
199	}
200}
201
202func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
203	encoderFields := []encoderField{}
204	extraEncoder := (*encoderField)(nil)
205
206	// This helper allows us to recursively collect field encoders into a flat
207	// array. The parameter `index` keeps track of the access patterns necessary
208	// to get to some field.
209	var collectEncoderFields func(r reflect.Type, index []int)
210	collectEncoderFields = func(r reflect.Type, index []int) {
211		for i := 0; i < r.NumField(); i++ {
212			idx := append(index, i)
213			field := t.FieldByIndex(idx)
214			if !field.IsExported() {
215				continue
216			}
217			// If this is an embedded struct, traverse one level deeper to extract
218			// the field and get their encoders as well.
219			if field.Anonymous {
220				collectEncoderFields(field.Type, idx)
221				continue
222			}
223			// If json tag is not present, then we skip, which is intentionally
224			// different behavior from the stdlib.
225			ptag, ok := parseJSONStructTag(field)
226			if !ok {
227				continue
228			}
229			// We only want to support unexported field if they're tagged with
230			// `extras` because that field shouldn't be part of the public API. We
231			// also want to only keep the top level extras
232			if ptag.extras && len(index) == 0 {
233				extraEncoder = &encoderField{ptag, e.typeEncoder(field.Type.Elem()), idx}
234				continue
235			}
236			if ptag.name == "-" {
237				continue
238			}
239
240			dateFormat, ok := parseFormatStructTag(field)
241			oldFormat := e.dateFormat
242			if ok {
243				switch dateFormat {
244				case "date-time":
245					e.dateFormat = time.RFC3339
246				case "date":
247					e.dateFormat = "2006-01-02"
248				}
249			}
250			encoderFields = append(encoderFields, encoderField{ptag, e.typeEncoder(field.Type), idx})
251			e.dateFormat = oldFormat
252		}
253	}
254	collectEncoderFields(t, []int{})
255
256	// Ensure deterministic output by sorting by lexicographic order
257	sort.Slice(encoderFields, func(i, j int) bool {
258		return encoderFields[i].tag.name < encoderFields[j].tag.name
259	})
260
261	return func(value reflect.Value) (json []byte, err error) {
262		json = []byte("{}")
263
264		for _, ef := range encoderFields {
265			field := value.FieldByIndex(ef.idx)
266			encoded, err := ef.fn(field)
267			if err != nil {
268				return nil, err
269			}
270			if encoded == nil {
271				continue
272			}
273			json, err = sjson.SetRawBytes(json, ef.tag.name, encoded)
274			if err != nil {
275				return nil, err
276			}
277		}
278
279		if extraEncoder != nil {
280			json, err = e.encodeMapEntries(json, value.FieldByIndex(extraEncoder.idx))
281			if err != nil {
282				return nil, err
283			}
284		}
285		return
286	}
287}
288
289func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc {
290	f, _ := t.FieldByName("Value")
291	enc := e.typeEncoder(f.Type)
292
293	return func(value reflect.Value) (json []byte, err error) {
294		present := value.FieldByName("Present")
295		if !present.Bool() {
296			return nil, nil
297		}
298		null := value.FieldByName("Null")
299		if null.Bool() {
300			return []byte("null"), nil
301		}
302		raw := value.FieldByName("Raw")
303		if !raw.IsNil() {
304			return e.typeEncoder(raw.Type())(raw)
305		}
306		return enc(value.FieldByName("Value"))
307	}
308}
309
310func (e *encoder) newTimeTypeEncoder() encoderFunc {
311	format := e.dateFormat
312	return func(value reflect.Value) (json []byte, err error) {
313		return []byte(`"` + value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format) + `"`), nil
314	}
315}
316
317func (e encoder) newInterfaceEncoder() encoderFunc {
318	return func(value reflect.Value) ([]byte, error) {
319		value = value.Elem()
320		if !value.IsValid() {
321			return nil, nil
322		}
323		return e.typeEncoder(value.Type())(value)
324	}
325}
326
327// Given a []byte of json (may either be an empty object or an object that already contains entries)
328// encode all of the entries in the map to the json byte array.
329func (e *encoder) encodeMapEntries(json []byte, v reflect.Value) ([]byte, error) {
330	type mapPair struct {
331		key   []byte
332		value reflect.Value
333	}
334
335	pairs := []mapPair{}
336	keyEncoder := e.typeEncoder(v.Type().Key())
337
338	iter := v.MapRange()
339	for iter.Next() {
340		var encodedKeyString string
341		if iter.Key().Type().Kind() == reflect.String {
342			encodedKeyString = iter.Key().String()
343		} else {
344			var err error
345			encodedKeyBytes, err := keyEncoder(iter.Key())
346			if err != nil {
347				return nil, err
348			}
349			encodedKeyString = string(encodedKeyBytes)
350		}
351		encodedKey := []byte(sjsonReplacer.Replace(encodedKeyString))
352		pairs = append(pairs, mapPair{key: encodedKey, value: iter.Value()})
353	}
354
355	// Ensure deterministic output
356	sort.Slice(pairs, func(i, j int) bool {
357		return bytes.Compare(pairs[i].key, pairs[j].key) < 0
358	})
359
360	elementEncoder := e.typeEncoder(v.Type().Elem())
361	for _, p := range pairs {
362		encodedValue, err := elementEncoder(p.value)
363		if err != nil {
364			return nil, err
365		}
366		if len(encodedValue) == 0 {
367			continue
368		}
369		json, err = sjson.SetRawBytes(json, string(p.key), encodedValue)
370		if err != nil {
371			return nil, err
372		}
373	}
374
375	return json, nil
376}
377
378func (e *encoder) newMapEncoder(_ reflect.Type) encoderFunc {
379	return func(value reflect.Value) ([]byte, error) {
380		json := []byte("{}")
381		var err error
382		json, err = e.encodeMapEntries(json, value)
383		if err != nil {
384			return nil, err
385		}
386		return json, nil
387	}
388}
389
// sjsonReplacer escapes a literal object key so that sjson treats it as a
// single key rather than as a path expression.
//
// sjson's path syntax uses the following special characters:
//   - '.' separates path components;
//   - ':' forces an object key;
//   - '*' and '?' are wildcards;
//   - '|', '#' and '@' are modifier and query syntax;
//   - '\' is the escape character itself.
//
// In escape mode sjson takes the byte after a backslash literally. Escaping a
// character that happens not to be special is therefore harmless, while
// leaving one unescaped either changes the key (e.g. `a\b` would become `ab`)
// or makes sjson reject the path.
var sjsonReplacer = strings.NewReplacer(
	"\\", "\\\\",
	".", "\\.",
	":", "\\:",
	"*", "\\*",
	"?", "\\?",
	"|", "\\|",
	"#", "\\#",
	"@", "\\@",
)