1package apijson
2
3import (
4 "bytes"
5 "encoding/json"
6 "fmt"
7 "reflect"
8 "sort"
9 "strconv"
10 "strings"
11 "sync"
12 "time"
13
14 "github.com/tidwall/sjson"
15
16 "github.com/openai/openai-go/internal/param"
17)
18
// encoders is the process-wide cache of built encoder functions. It is keyed
// by encoderEntry (type plus encoder settings) and stores encoderFunc values;
// typeEncoder is the code that reads and populates it.
var encoders sync.Map // map[encoderEntry]encoderFunc
20
21func Marshal(value interface{}) ([]byte, error) {
22 e := &encoder{dateFormat: time.RFC3339}
23 return e.marshal(value)
24}
25
26func MarshalRoot(value interface{}) ([]byte, error) {
27 e := &encoder{root: true, dateFormat: time.RFC3339}
28 return e.marshal(value)
29}
30
// encoder holds the settings that are in effect while encoder functions are
// being built. Both fields are part of the cache key (see encoderEntry).
type encoder struct {
	// dateFormat is the time layout handed to time.Time.Format by the time
	// encoder. newStructTypeEncoder overrides it temporarily for fields that
	// carry a format struct tag.
	dateFormat string
	// root, when true, makes newTypeEncoder skip its json.Marshaler checks.
	// newTypeEncoder resets it to false before it dispatches on the kind, so
	// nested types are built with root == false.
	root bool
}
35
// encoderFunc encodes a single reflected value to raw JSON bytes. Within this
// file a nil result with a nil error means "nothing to emit": struct and map
// encoders skip such entries and the array encoder substitutes null.
type encoderFunc func(value reflect.Value) ([]byte, error)
37
// encoderField describes one struct field collected by newStructTypeEncoder.
type encoderField struct {
	// tag is the parsed json struct tag; tag.name is used as the output key.
	tag parsedStructTag
	// fn encodes the field's value.
	fn encoderFunc
	// idx is the index path to the field, suitable for reflect.Value.FieldByIndex.
	idx []int
}
43
// encoderEntry is the key type of the encoders cache. Two lookups share a
// cached encoder only if the type and both encoder settings are equal.
type encoderEntry struct {
	reflect.Type
	dateFormat string
	root bool
}
49
50func (e *encoder) marshal(value interface{}) ([]byte, error) {
51 val := reflect.ValueOf(value)
52 if !val.IsValid() {
53 return nil, nil
54 }
55 typ := val.Type()
56 enc := e.typeEncoder(typ)
57 return enc(val)
58}
59
60func (e *encoder) typeEncoder(t reflect.Type) encoderFunc {
61 entry := encoderEntry{
62 Type: t,
63 dateFormat: e.dateFormat,
64 root: e.root,
65 }
66
67 if fi, ok := encoders.Load(entry); ok {
68 return fi.(encoderFunc)
69 }
70
71 // To deal with recursive types, populate the map with an
72 // indirect func before we build it. This type waits on the
73 // real func (f) to be ready and then calls it. This indirect
74 // func is only used for recursive types.
75 var (
76 wg sync.WaitGroup
77 f encoderFunc
78 )
79 wg.Add(1)
80 fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(v reflect.Value) ([]byte, error) {
81 wg.Wait()
82 return f(v)
83 }))
84 if loaded {
85 return fi.(encoderFunc)
86 }
87
88 // Compute the real encoder and replace the indirect func with it.
89 f = e.newTypeEncoder(t)
90 wg.Done()
91 encoders.Store(entry, f)
92 return f
93}
94
// marshalerEncoder encodes v by delegating to the MarshalJSON method of the
// value itself. The type assertion panics if v's type is not a json.Marshaler.
func marshalerEncoder(v reflect.Value) ([]byte, error) {
	m := v.Interface().(json.Marshaler)
	return m.MarshalJSON()
}
98
// indirectMarshalerEncoder encodes a value whose pointer type implements
// json.Marshaler by calling MarshalJSON on a pointer to that value.
//
// reflect.Value.Addr panics on values that are not addressable, such as a
// struct handed to Marshal by value or an element read out of a map. Those
// values are copied into freshly allocated storage first so that a pointer
// can be taken; addressable values are used in place.
func indirectMarshalerEncoder(v reflect.Value) ([]byte, error) {
	if !v.CanAddr() {
		addressable := reflect.New(v.Type()).Elem()
		addressable.Set(v)
		v = addressable
	}
	return v.Addr().Interface().(json.Marshaler).MarshalJSON()
}
102
103func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc {
104 if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
105 return e.newTimeTypeEncoder()
106 }
107 if !e.root && t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) {
108 return marshalerEncoder
109 }
110 if !e.root && reflect.PointerTo(t).Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) {
111 return indirectMarshalerEncoder
112 }
113 e.root = false
114 switch t.Kind() {
115 case reflect.Pointer:
116 inner := t.Elem()
117
118 innerEncoder := e.typeEncoder(inner)
119 return func(v reflect.Value) ([]byte, error) {
120 if !v.IsValid() || v.IsNil() {
121 return nil, nil
122 }
123 return innerEncoder(v.Elem())
124 }
125 case reflect.Struct:
126 return e.newStructTypeEncoder(t)
127 case reflect.Array:
128 fallthrough
129 case reflect.Slice:
130 return e.newArrayTypeEncoder(t)
131 case reflect.Map:
132 return e.newMapEncoder(t)
133 case reflect.Interface:
134 return e.newInterfaceEncoder()
135 default:
136 return e.newPrimitiveTypeEncoder(t)
137 }
138}
139
140func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc {
141 switch t.Kind() {
142 // Note that we could use `gjson` to encode these types but it would complicate our
143 // code more and this current code shouldn't cause any issues
144 case reflect.String:
145 return func(v reflect.Value) ([]byte, error) {
146 return json.Marshal(v.Interface())
147 }
148 case reflect.Bool:
149 return func(v reflect.Value) ([]byte, error) {
150 if v.Bool() {
151 return []byte("true"), nil
152 }
153 return []byte("false"), nil
154 }
155 case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64:
156 return func(v reflect.Value) ([]byte, error) {
157 return []byte(strconv.FormatInt(v.Int(), 10)), nil
158 }
159 case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64:
160 return func(v reflect.Value) ([]byte, error) {
161 return []byte(strconv.FormatUint(v.Uint(), 10)), nil
162 }
163 case reflect.Float32:
164 return func(v reflect.Value) ([]byte, error) {
165 return []byte(strconv.FormatFloat(v.Float(), 'f', -1, 32)), nil
166 }
167 case reflect.Float64:
168 return func(v reflect.Value) ([]byte, error) {
169 return []byte(strconv.FormatFloat(v.Float(), 'f', -1, 64)), nil
170 }
171 default:
172 return func(v reflect.Value) ([]byte, error) {
173 return nil, fmt.Errorf("unknown type received at primitive encoder: %s", t.String())
174 }
175 }
176}
177
178func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc {
179 itemEncoder := e.typeEncoder(t.Elem())
180
181 return func(value reflect.Value) ([]byte, error) {
182 json := []byte("[]")
183 for i := 0; i < value.Len(); i++ {
184 var value, err = itemEncoder(value.Index(i))
185 if err != nil {
186 return nil, err
187 }
188 if value == nil {
189 // Assume that empty items should be inserted as `null` so that the output array
190 // will be the same length as the input array
191 value = []byte("null")
192 }
193
194 json, err = sjson.SetRawBytes(json, "-1", value)
195 if err != nil {
196 return nil, err
197 }
198 }
199
200 return json, nil
201 }
202}
203
204func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
205 if t.Implements(reflect.TypeOf((*param.FieldLike)(nil)).Elem()) {
206 return e.newFieldTypeEncoder(t)
207 }
208
209 encoderFields := []encoderField{}
210 extraEncoder := (*encoderField)(nil)
211
212 // This helper allows us to recursively collect field encoders into a flat
213 // array. The parameter `index` keeps track of the access patterns necessary
214 // to get to some field.
215 var collectEncoderFields func(r reflect.Type, index []int)
216 collectEncoderFields = func(r reflect.Type, index []int) {
217 for i := 0; i < r.NumField(); i++ {
218 idx := append(index, i)
219 field := t.FieldByIndex(idx)
220 if !field.IsExported() {
221 continue
222 }
223 // If this is an embedded struct, traverse one level deeper to extract
224 // the field and get their encoders as well.
225 if field.Anonymous {
226 collectEncoderFields(field.Type, idx)
227 continue
228 }
229 // If json tag is not present, then we skip, which is intentionally
230 // different behavior from the stdlib.
231 ptag, ok := parseJSONStructTag(field)
232 if !ok {
233 continue
234 }
235 // We only want to support unexported field if they're tagged with
236 // `extras` because that field shouldn't be part of the public API. We
237 // also want to only keep the top level extras
238 if ptag.extras && len(index) == 0 {
239 extraEncoder = &encoderField{ptag, e.typeEncoder(field.Type.Elem()), idx}
240 continue
241 }
242 if ptag.name == "-" {
243 continue
244 }
245
246 dateFormat, ok := parseFormatStructTag(field)
247 oldFormat := e.dateFormat
248 if ok {
249 switch dateFormat {
250 case "date-time":
251 e.dateFormat = time.RFC3339
252 case "date":
253 e.dateFormat = "2006-01-02"
254 }
255 }
256 encoderFields = append(encoderFields, encoderField{ptag, e.typeEncoder(field.Type), idx})
257 e.dateFormat = oldFormat
258 }
259 }
260 collectEncoderFields(t, []int{})
261
262 // Ensure deterministic output by sorting by lexicographic order
263 sort.Slice(encoderFields, func(i, j int) bool {
264 return encoderFields[i].tag.name < encoderFields[j].tag.name
265 })
266
267 return func(value reflect.Value) (json []byte, err error) {
268 json = []byte("{}")
269
270 for _, ef := range encoderFields {
271 field := value.FieldByIndex(ef.idx)
272 encoded, err := ef.fn(field)
273 if err != nil {
274 return nil, err
275 }
276 if encoded == nil {
277 continue
278 }
279 json, err = sjson.SetRawBytes(json, ef.tag.name, encoded)
280 if err != nil {
281 return nil, err
282 }
283 }
284
285 if extraEncoder != nil {
286 json, err = e.encodeMapEntries(json, value.FieldByIndex(extraEncoder.idx))
287 if err != nil {
288 return nil, err
289 }
290 }
291 return
292 }
293}
294
295func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc {
296 f, _ := t.FieldByName("Value")
297 enc := e.typeEncoder(f.Type)
298
299 return func(value reflect.Value) (json []byte, err error) {
300 present := value.FieldByName("Present")
301 if !present.Bool() {
302 return nil, nil
303 }
304 null := value.FieldByName("Null")
305 if null.Bool() {
306 return []byte("null"), nil
307 }
308 raw := value.FieldByName("Raw")
309 if !raw.IsNil() {
310 return e.typeEncoder(raw.Type())(raw)
311 }
312 return enc(value.FieldByName("Value"))
313 }
314}
315
316func (e *encoder) newTimeTypeEncoder() encoderFunc {
317 format := e.dateFormat
318 return func(value reflect.Value) (json []byte, err error) {
319 return []byte(`"` + value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format) + `"`), nil
320 }
321}
322
323func (e encoder) newInterfaceEncoder() encoderFunc {
324 return func(value reflect.Value) ([]byte, error) {
325 value = value.Elem()
326 if !value.IsValid() {
327 return nil, nil
328 }
329 return e.typeEncoder(value.Type())(value)
330 }
331}
332
333// Given a []byte of json (may either be an empty object or an object that already contains entries)
334// encode all of the entries in the map to the json byte array.
335func (e *encoder) encodeMapEntries(json []byte, v reflect.Value) ([]byte, error) {
336 type mapPair struct {
337 key []byte
338 value reflect.Value
339 }
340
341 pairs := []mapPair{}
342 keyEncoder := e.typeEncoder(v.Type().Key())
343
344 iter := v.MapRange()
345 for iter.Next() {
346 var encodedKeyString string
347 if iter.Key().Type().Kind() == reflect.String {
348 encodedKeyString = iter.Key().String()
349 } else {
350 var err error
351 encodedKeyBytes, err := keyEncoder(iter.Key())
352 if err != nil {
353 return nil, err
354 }
355 encodedKeyString = string(encodedKeyBytes)
356 }
357 encodedKey := []byte(sjsonReplacer.Replace(encodedKeyString))
358 pairs = append(pairs, mapPair{key: encodedKey, value: iter.Value()})
359 }
360
361 // Ensure deterministic output
362 sort.Slice(pairs, func(i, j int) bool {
363 return bytes.Compare(pairs[i].key, pairs[j].key) < 0
364 })
365
366 elementEncoder := e.typeEncoder(v.Type().Elem())
367 for _, p := range pairs {
368 encodedValue, err := elementEncoder(p.value)
369 if err != nil {
370 return nil, err
371 }
372 if len(encodedValue) == 0 {
373 continue
374 }
375 json, err = sjson.SetRawBytes(json, string(p.key), encodedValue)
376 if err != nil {
377 return nil, err
378 }
379 }
380
381 return json, nil
382}
383
384func (e *encoder) newMapEncoder(_ reflect.Type) encoderFunc {
385 return func(value reflect.Value) ([]byte, error) {
386 json := []byte("{}")
387 var err error
388 json, err = e.encodeMapEntries(json, value)
389 if err != nil {
390 return nil, err
391 }
392 return json, nil
393 }
394}
395
// sjsonReplacer backslash-escapes the characters '.', ':' and '*' in a key so
// that sjson treats the key as a literal name instead of interpreting those
// characters as path syntax.
var sjsonReplacer = strings.NewReplacer(
	".", "\\.",
	":", "\\:",
	"*", "\\*",
)