1package apijson
2
3import (
4 "bytes"
5 "encoding/json"
6 "fmt"
7 "reflect"
8 "sort"
9 "strconv"
10 "strings"
11 "sync"
12 "time"
13
14 "github.com/tidwall/sjson"
15)
16
// encoders caches built encoder functions. Keys are encoderEntry values
// (type plus date format plus root flag) and values are encoderFunc; it is
// read and written by typeEncoder through Load, LoadOrStore and Store.
var encoders sync.Map // map[encoderEntry]encoderFunc
18
19func Marshal(value any) ([]byte, error) {
20 e := &encoder{dateFormat: time.RFC3339}
21 return e.marshal(value)
22}
23
24func MarshalRoot(value any) ([]byte, error) {
25 e := &encoder{root: true, dateFormat: time.RFC3339}
26 return e.marshal(value)
27}
28
// encoder holds the settings that determine which encoder function is built
// for a type. Both fields are part of the cache key used by typeEncoder.
type encoder struct {
	// dateFormat is the time layout captured by encoders built for time
	// values. newStructTypeEncoder overrides it temporarily for fields that
	// carry a format tag.
	dateFormat string
	// root is true only until newTypeEncoder has run once; while it is set,
	// json.Marshaler implementations are not consulted.
	root bool
}
33
// encoderFunc encodes a reflected value to raw JSON bytes. A nil result with a
// nil error means "no value": the struct encoder omits such a field and the
// array encoder writes null in its place.
type encoderFunc func(value reflect.Value) ([]byte, error)
35
// encoderField describes one struct field collected by newStructTypeEncoder.
// It is constructed with positional literals, so the field order matters.
type encoderField struct {
	// tag is the parsed json struct tag of the field (parsedStructTag is
	// declared outside this section of the file).
	tag parsedStructTag
	// fn encodes a value of the field's type.
	fn encoderFunc
	// idx is the index path handed to reflect.Value.FieldByIndex to reach the
	// field from the outermost struct.
	idx []int
}
41
// encoderEntry is the key type of the encoders cache: the Go type being
// encoded together with the encoder settings that were active when the lookup
// was made.
type encoderEntry struct {
	reflect.Type
	dateFormat string
	root       bool
}
47
48func (e *encoder) marshal(value any) ([]byte, error) {
49 val := reflect.ValueOf(value)
50 if !val.IsValid() {
51 return nil, nil
52 }
53 typ := val.Type()
54 enc := e.typeEncoder(typ)
55 return enc(val)
56}
57
// typeEncoder returns the encoder for t under the encoder's current settings,
// building and caching it on first use.
//
// The cache key is computed up front from the current dateFormat and root
// values; newTypeEncoder may clear e.root afterwards, which does not affect
// the key the result is stored under.
func (e *encoder) typeEncoder(t reflect.Type) encoderFunc {
	entry := encoderEntry{
		Type:       t,
		dateFormat: e.dateFormat,
		root:       e.root,
	}

	// Fast path: an encoder (or a placeholder for one) is already cached.
	if fi, ok := encoders.Load(entry); ok {
		return fi.(encoderFunc)
	}

	// To deal with recursive types, populate the map with an
	// indirect func before we build it. This func waits on the
	// real func (f) to be ready and then calls it. This indirect
	// func is only used for recursive types.
	var (
		wg sync.WaitGroup
		f  encoderFunc
	)
	wg.Add(1)
	fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(v reflect.Value) ([]byte, error) {
		wg.Wait()
		return f(v)
	}))
	// Another caller stored an entry between Load and LoadOrStore; use theirs
	// and discard the placeholder created above.
	if loaded {
		return fi.(encoderFunc)
	}

	// Compute the real encoder and replace the indirect func with it.
	// f must be assigned before wg.Done so that the placeholder, once
	// released from wg.Wait, observes the finished encoder.
	f = e.newTypeEncoder(t)
	wg.Done()
	encoders.Store(entry, f)
	return f
}
92
// marshalerEncoder encodes v by calling the MarshalJSON method of its
// json.Marshaler implementation.
//
// A nil pointer or nil interface is reported as "no value" (nil, nil) without
// calling MarshalJSON, matching how the pointer and interface encoders in this
// file treat nil. Dispatching on nil would panic: a value-receiver method
// dereferences the nil pointer, and a nil interface fails the type assertion.
func marshalerEncoder(v reflect.Value) ([]byte, error) {
	switch v.Kind() {
	case reflect.Pointer, reflect.Interface:
		if v.IsNil() {
			return nil, nil
		}
	}
	return v.Interface().(json.Marshaler).MarshalJSON()
}
96
// indirectMarshalerEncoder encodes v through the json.Marshaler implemented
// by a pointer to v's type.
//
// When v is addressable the method is called on v's own storage. Values that
// are not addressable (for example map elements or a struct passed by value)
// cannot have their address taken, so they are first copied into freshly
// allocated storage instead of panicking in reflect.Value.Addr.
func indirectMarshalerEncoder(v reflect.Value) ([]byte, error) {
	if !v.CanAddr() {
		addressable := reflect.New(v.Type()).Elem()
		addressable.Set(v)
		v = addressable
	}
	return v.Addr().Interface().(json.Marshaler).MarshalJSON()
}
100
101func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc {
102 if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
103 return e.newTimeTypeEncoder()
104 }
105 if !e.root && t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) {
106 return marshalerEncoder
107 }
108 if !e.root && reflect.PointerTo(t).Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) {
109 return indirectMarshalerEncoder
110 }
111 e.root = false
112 switch t.Kind() {
113 case reflect.Pointer:
114 inner := t.Elem()
115
116 innerEncoder := e.typeEncoder(inner)
117 return func(v reflect.Value) ([]byte, error) {
118 if !v.IsValid() || v.IsNil() {
119 return nil, nil
120 }
121 return innerEncoder(v.Elem())
122 }
123 case reflect.Struct:
124 return e.newStructTypeEncoder(t)
125 case reflect.Array:
126 fallthrough
127 case reflect.Slice:
128 return e.newArrayTypeEncoder(t)
129 case reflect.Map:
130 return e.newMapEncoder(t)
131 case reflect.Interface:
132 return e.newInterfaceEncoder()
133 default:
134 return e.newPrimitiveTypeEncoder(t)
135 }
136}
137
138func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc {
139 switch t.Kind() {
140 // Note that we could use `gjson` to encode these types but it would complicate our
141 // code more and this current code shouldn't cause any issues
142 case reflect.String:
143 return func(v reflect.Value) ([]byte, error) {
144 return json.Marshal(v.Interface())
145 }
146 case reflect.Bool:
147 return func(v reflect.Value) ([]byte, error) {
148 if v.Bool() {
149 return []byte("true"), nil
150 }
151 return []byte("false"), nil
152 }
153 case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64:
154 return func(v reflect.Value) ([]byte, error) {
155 return []byte(strconv.FormatInt(v.Int(), 10)), nil
156 }
157 case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64:
158 return func(v reflect.Value) ([]byte, error) {
159 return []byte(strconv.FormatUint(v.Uint(), 10)), nil
160 }
161 case reflect.Float32:
162 return func(v reflect.Value) ([]byte, error) {
163 return []byte(strconv.FormatFloat(v.Float(), 'f', -1, 32)), nil
164 }
165 case reflect.Float64:
166 return func(v reflect.Value) ([]byte, error) {
167 return []byte(strconv.FormatFloat(v.Float(), 'f', -1, 64)), nil
168 }
169 default:
170 return func(v reflect.Value) ([]byte, error) {
171 return nil, fmt.Errorf("unknown type received at primitive encoder: %s", t.String())
172 }
173 }
174}
175
176func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc {
177 itemEncoder := e.typeEncoder(t.Elem())
178
179 return func(value reflect.Value) ([]byte, error) {
180 json := []byte("[]")
181 for i := 0; i < value.Len(); i++ {
182 var value, err = itemEncoder(value.Index(i))
183 if err != nil {
184 return nil, err
185 }
186 if value == nil {
187 // Assume that empty items should be inserted as `null` so that the output array
188 // will be the same length as the input array
189 value = []byte("null")
190 }
191
192 json, err = sjson.SetRawBytes(json, "-1", value)
193 if err != nil {
194 return nil, err
195 }
196 }
197
198 return json, nil
199 }
200}
201
202func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
203 encoderFields := []encoderField{}
204 extraEncoder := (*encoderField)(nil)
205
206 // This helper allows us to recursively collect field encoders into a flat
207 // array. The parameter `index` keeps track of the access patterns necessary
208 // to get to some field.
209 var collectEncoderFields func(r reflect.Type, index []int)
210 collectEncoderFields = func(r reflect.Type, index []int) {
211 for i := 0; i < r.NumField(); i++ {
212 idx := append(index, i)
213 field := t.FieldByIndex(idx)
214 if !field.IsExported() {
215 continue
216 }
217 // If this is an embedded struct, traverse one level deeper to extract
218 // the field and get their encoders as well.
219 if field.Anonymous {
220 collectEncoderFields(field.Type, idx)
221 continue
222 }
223 // If json tag is not present, then we skip, which is intentionally
224 // different behavior from the stdlib.
225 ptag, ok := parseJSONStructTag(field)
226 if !ok {
227 continue
228 }
229 // We only want to support unexported field if they're tagged with
230 // `extras` because that field shouldn't be part of the public API. We
231 // also want to only keep the top level extras
232 if ptag.extras && len(index) == 0 {
233 extraEncoder = &encoderField{ptag, e.typeEncoder(field.Type.Elem()), idx}
234 continue
235 }
236 if ptag.name == "-" {
237 continue
238 }
239
240 dateFormat, ok := parseFormatStructTag(field)
241 oldFormat := e.dateFormat
242 if ok {
243 switch dateFormat {
244 case "date-time":
245 e.dateFormat = time.RFC3339
246 case "date":
247 e.dateFormat = "2006-01-02"
248 }
249 }
250 encoderFields = append(encoderFields, encoderField{ptag, e.typeEncoder(field.Type), idx})
251 e.dateFormat = oldFormat
252 }
253 }
254 collectEncoderFields(t, []int{})
255
256 // Ensure deterministic output by sorting by lexicographic order
257 sort.Slice(encoderFields, func(i, j int) bool {
258 return encoderFields[i].tag.name < encoderFields[j].tag.name
259 })
260
261 return func(value reflect.Value) (json []byte, err error) {
262 json = []byte("{}")
263
264 for _, ef := range encoderFields {
265 field := value.FieldByIndex(ef.idx)
266 encoded, err := ef.fn(field)
267 if err != nil {
268 return nil, err
269 }
270 if encoded == nil {
271 continue
272 }
273 json, err = sjson.SetRawBytes(json, ef.tag.name, encoded)
274 if err != nil {
275 return nil, err
276 }
277 }
278
279 if extraEncoder != nil {
280 json, err = e.encodeMapEntries(json, value.FieldByIndex(extraEncoder.idx))
281 if err != nil {
282 return nil, err
283 }
284 }
285 return
286 }
287}
288
289func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc {
290 f, _ := t.FieldByName("Value")
291 enc := e.typeEncoder(f.Type)
292
293 return func(value reflect.Value) (json []byte, err error) {
294 present := value.FieldByName("Present")
295 if !present.Bool() {
296 return nil, nil
297 }
298 null := value.FieldByName("Null")
299 if null.Bool() {
300 return []byte("null"), nil
301 }
302 raw := value.FieldByName("Raw")
303 if !raw.IsNil() {
304 return e.typeEncoder(raw.Type())(raw)
305 }
306 return enc(value.FieldByName("Value"))
307 }
308}
309
310func (e *encoder) newTimeTypeEncoder() encoderFunc {
311 format := e.dateFormat
312 return func(value reflect.Value) (json []byte, err error) {
313 return []byte(`"` + value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format) + `"`), nil
314 }
315}
316
317func (e encoder) newInterfaceEncoder() encoderFunc {
318 return func(value reflect.Value) ([]byte, error) {
319 value = value.Elem()
320 if !value.IsValid() {
321 return nil, nil
322 }
323 return e.typeEncoder(value.Type())(value)
324 }
325}
326
327// Given a []byte of json (may either be an empty object or an object that already contains entries)
328// encode all of the entries in the map to the json byte array.
329func (e *encoder) encodeMapEntries(json []byte, v reflect.Value) ([]byte, error) {
330 type mapPair struct {
331 key []byte
332 value reflect.Value
333 }
334
335 pairs := []mapPair{}
336 keyEncoder := e.typeEncoder(v.Type().Key())
337
338 iter := v.MapRange()
339 for iter.Next() {
340 var encodedKeyString string
341 if iter.Key().Type().Kind() == reflect.String {
342 encodedKeyString = iter.Key().String()
343 } else {
344 var err error
345 encodedKeyBytes, err := keyEncoder(iter.Key())
346 if err != nil {
347 return nil, err
348 }
349 encodedKeyString = string(encodedKeyBytes)
350 }
351 encodedKey := []byte(sjsonReplacer.Replace(encodedKeyString))
352 pairs = append(pairs, mapPair{key: encodedKey, value: iter.Value()})
353 }
354
355 // Ensure deterministic output
356 sort.Slice(pairs, func(i, j int) bool {
357 return bytes.Compare(pairs[i].key, pairs[j].key) < 0
358 })
359
360 elementEncoder := e.typeEncoder(v.Type().Elem())
361 for _, p := range pairs {
362 encodedValue, err := elementEncoder(p.value)
363 if err != nil {
364 return nil, err
365 }
366 if len(encodedValue) == 0 {
367 continue
368 }
369 json, err = sjson.SetRawBytes(json, string(p.key), encodedValue)
370 if err != nil {
371 return nil, err
372 }
373 }
374
375 return json, nil
376}
377
378func (e *encoder) newMapEncoder(_ reflect.Type) encoderFunc {
379 return func(value reflect.Value) ([]byte, error) {
380 json := []byte("{}")
381 var err error
382 json, err = e.encodeMapEntries(json, value)
383 if err != nil {
384 return nil, err
385 }
386 return json, nil
387 }
388}
389
// sjsonReplacer backslash-escapes the characters '.', ':' and '*' in a key so
// that sjson treats the key as a literal name instead of interpreting those
// characters as path syntax.
var sjsonReplacer = strings.NewReplacer(
	".", "\\.",
	":", "\\:",
	"*", "\\*",
)