1package apiform
2
3import (
4 "fmt"
5 "io"
6 "mime/multipart"
7 "net/textproto"
8 "path"
9 "reflect"
10 "sort"
11 "strconv"
12 "strings"
13 "sync"
14 "time"
15
16 internalparam "github.com/openai/openai-go/internal/param"
17 "github.com/openai/openai-go/packages/param"
18)
19
// encoders caches built encoder functions for the whole package. Keys are
// encoderEntry values (a reflect.Type plus the dateFormat and root settings
// it was built with); values are encoderFunc. It is shared by every call to
// Marshal and MarshalRoot and is populated lazily by typeEncoder.
var encoders sync.Map // map[encoderEntry]encoderFunc
21
// Marshal writes value into writer as multipart/form-data fields.
//
// Values convertible to time.Time are formatted with time.RFC3339 unless a
// struct field's format tag selects another layout, and values convertible to
// io.Reader are written as file parts. An untyped nil value writes nothing.
func Marshal(value interface{}, writer *multipart.Writer) error {
	return (&encoder{dateFormat: time.RFC3339}).marshal(value, writer)
}
26
// MarshalRoot behaves like Marshal but starts encoding with the root flag
// set. The flag is part of the encoder cache key, so the top-level encoder
// built here is cached separately from the one Marshal uses.
func MarshalRoot(value interface{}, writer *multipart.Writer) error {
	enc := encoder{
		dateFormat: time.RFC3339,
		root:       true,
	}
	return enc.marshal(value, writer)
}
31
// encoder carries the settings that influence how type encoders are built.
type encoder struct {
	// dateFormat is the time layout used when building encoders for values
	// convertible to time.Time. Struct field format tags temporarily
	// override it while that field's encoder is built.
	dateFormat string
	// root is set by MarshalRoot and cleared by newTypeEncoder before it
	// dispatches on a type's kind. Within this file it only distinguishes
	// cache entries. NOTE(review): confirm whether other files read it.
	root bool
}

// encoderFunc writes the form representation of value under the field name
// key into writer. An empty key means "no prefix".
type encoderFunc func(key string, value reflect.Value, writer *multipart.Writer) error

// encoderField is one struct field collected by newStructTypeEncoder.
type encoderField struct {
	tag parsedStructTag // parsed form struct tag (name, omitzero, extras)
	fn  encoderFunc     // encoder for the field's value
	idx []int           // index path passed to reflect's FieldByIndex
}

// encoderEntry is the cache key for the encoders map: the encoder for a
// type depends on the date format and root flag in effect when it was built.
type encoderEntry struct {
	reflect.Type
	dateFormat string
	root       bool
}
50
// marshal looks up (or builds) the encoder for value's dynamic type and runs
// it with an empty key, so top-level fields are written without a prefix.
// An untyped nil value is a no-op.
func (e *encoder) marshal(value interface{}, writer *multipart.Writer) error {
	rv := reflect.ValueOf(value)
	if rv.IsValid() {
		return e.typeEncoder(rv.Type())("", rv, writer)
	}
	return nil
}
60
// typeEncoder returns the encoder for t under the encoder's current
// dateFormat and root settings, building and caching it on first use.
//
// Before building, a placeholder is published in the cache. A recursive
// type that reaches itself while its encoder is being built (or a
// concurrent caller) receives that placeholder, which blocks until the real
// encoder is ready and then delegates to it.
func (e *encoder) typeEncoder(t reflect.Type) encoderFunc {
	cacheKey := encoderEntry{Type: t, dateFormat: e.dateFormat, root: e.root}

	if cached, ok := encoders.Load(cacheKey); ok {
		return cached.(encoderFunc)
	}

	var (
		ready sync.WaitGroup
		built encoderFunc
	)
	ready.Add(1)
	placeholder := encoderFunc(func(k string, v reflect.Value, w *multipart.Writer) error {
		ready.Wait()
		return built(k, v, w)
	})
	// Another goroutine may have won the race to publish an entry; use it.
	if existing, loaded := encoders.LoadOrStore(cacheKey, placeholder); loaded {
		return existing.(encoderFunc)
	}

	// Build the real encoder, release anyone blocked on the placeholder,
	// then replace the placeholder so future lookups skip the indirection.
	built = e.newTypeEncoder(t)
	ready.Done()
	encoders.Store(cacheKey, built)
	return built
}
95
// newTypeEncoder builds a fresh (uncached) encoder for t; callers normally go
// through typeEncoder instead.
//
// Types convertible to time.Time or io.Reader take precedence over the
// kind-based dispatch. For example, a pointer type implementing io.Reader is
// uploaded as a file rather than dereferenced.
func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc {
	timeType := reflect.TypeOf(time.Time{})
	readerType := reflect.TypeOf((*io.Reader)(nil)).Elem()

	switch {
	case t.ConvertibleTo(timeType):
		return e.newTimeTypeEncoder()
	case t.ConvertibleTo(readerType):
		return e.newReaderTypeEncoder()
	}

	// From here on, every encoder built through e (including nested types)
	// is cached as non-root.
	e.root = false

	switch t.Kind() {
	case reflect.Pointer:
		elemEncoder := e.typeEncoder(t.Elem())
		return func(key string, v reflect.Value, writer *multipart.Writer) error {
			// Nil pointers produce no form field at all.
			if v.IsValid() && !v.IsNil() {
				return elemEncoder(key, v.Elem(), writer)
			}
			return nil
		}
	case reflect.Struct:
		return e.newStructTypeEncoder(t)
	case reflect.Slice, reflect.Array:
		return e.newArrayTypeEncoder(t)
	case reflect.Map:
		return e.newMapEncoder(t)
	case reflect.Interface:
		return e.newInterfaceEncoder()
	}
	return e.newPrimitiveTypeEncoder(t)
}
127
// newPrimitiveTypeEncoder builds an encoder for scalar kinds. Each value is
// written as a single text field named key:
//   - strings verbatim;
//   - booleans as "true" / "false";
//   - signed and unsigned integers of every width in base 10;
//   - floats in 'f' format with the minimal number of digits that represents
//     the value exactly at its own precision (32 or 64 bits).
//
// Any other kind yields an encoder that always returns an error.
func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc {
	switch t.Kind() {
	case reflect.String:
		return func(key string, v reflect.Value, writer *multipart.Writer) error {
			return writer.WriteField(key, v.String())
		}
	case reflect.Bool:
		return func(key string, v reflect.Value, writer *multipart.Writer) error {
			if v.Bool() {
				return writer.WriteField(key, "true")
			}
			return writer.WriteField(key, "false")
		}
	// Int8 / Uint8 / Uintptr were previously missing and fell through to the
	// error branch, so e.g. int8 fields or []byte elements could not encode.
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return func(key string, v reflect.Value, writer *multipart.Writer) error {
			return writer.WriteField(key, strconv.FormatInt(v.Int(), 10))
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return func(key string, v reflect.Value, writer *multipart.Writer) error {
			return writer.WriteField(key, strconv.FormatUint(v.Uint(), 10))
		}
	case reflect.Float32:
		return func(key string, v reflect.Value, writer *multipart.Writer) error {
			return writer.WriteField(key, strconv.FormatFloat(v.Float(), 'f', -1, 32))
		}
	case reflect.Float64:
		return func(key string, v reflect.Value, writer *multipart.Writer) error {
			return writer.WriteField(key, strconv.FormatFloat(v.Float(), 'f', -1, 64))
		}
	default:
		return func(key string, v reflect.Value, writer *multipart.Writer) error {
			return fmt.Errorf("unknown type received at primitive encoder: %s", t.String())
		}
	}
}
165
// newArrayTypeEncoder builds an encoder for slices and arrays. Element i is
// written under "<key>.<i>", or just "<i>" when key is empty. Encoding stops
// at the first element that fails.
func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc {
	elemEncoder := e.typeEncoder(t.Elem())

	return func(key string, v reflect.Value, writer *multipart.Writer) error {
		prefix := ""
		if key != "" {
			prefix = key + "."
		}
		for i, n := 0, v.Len(); i < n; i++ {
			if err := elemEncoder(prefix+strconv.Itoa(i), v.Index(i), writer); err != nil {
				return err
			}
		}
		return nil
	}
}
182
// newStructTypeEncoder builds an encoder for the struct type t.
//
// Dispatch for special struct types:
//   - types implementing internalparam.FieldLike go to newFieldTypeEncoder;
//   - types listed in param.OptionalPrimitiveTypes go to
//     newRichFieldTypeEncoder.
//
// For any other struct:
//   - each exported field with a form tag (as reported by parseFormStructTag)
//     whose name is neither "" nor "-" is written under "<key>.<name>", or
//     just "<name>" at the root;
//   - fields are emitted in lexicographic order of their tag names, so the
//     output is deterministic;
//   - exported embedded fields are flattened into the parent;
//   - a top-level field tagged as extras is written last via
//     encodeMapEntries, with its entries merged into the same key prefix.
func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
	if t.Implements(reflect.TypeOf((*internalparam.FieldLike)(nil)).Elem()) {
		return e.newFieldTypeEncoder(t)
	}

	if idx, ok := param.OptionalPrimitiveTypes[t]; ok {
		return e.newRichFieldTypeEncoder(t, idx)
	}

	encoderFields := []encoderField{}
	extraEncoder := (*encoderField)(nil)

	// collectEncoderFields recursively gathers field encoders into a flat
	// slice. index is the FieldByIndex path from t down to r.
	var collectEncoderFields func(r reflect.Type, index []int)
	collectEncoderFields = func(r reflect.Type, index []int) {
		for i := 0; i < r.NumField(); i++ {
			// Always copy into a fresh slice. A plain append(index, i) may
			// reuse index's spare capacity (depending on append's growth),
			// so sibling fields inside deeply nested embedded structs would
			// share one backing array and all keep the last sibling's path.
			idx := make([]int, len(index)+1)
			copy(idx, index)
			idx[len(index)] = i

			field := t.FieldByIndex(idx)
			if !field.IsExported() {
				continue
			}
			// Flatten embedded structs so their fields are encoded as if
			// declared on the parent.
			// NOTE(review): embedded fields are assumed to be struct values;
			// an embedded pointer or non-struct type makes r.NumField panic.
			if field.Anonymous {
				collectEncoderFields(field.Type, idx)
				continue
			}
			// Fields without a form tag are skipped, which is intentionally
			// different from encoding/json's default behavior.
			ptag, ok := parseFormStructTag(field)
			if !ok {
				continue
			}
			// Only an extras field declared directly on t is treated as the
			// extras map. One on an embedded struct is not special-cased and
			// falls through to the name checks below.
			if ptag.extras && len(index) == 0 {
				extraEncoder = &encoderField{ptag, e.typeEncoder(field.Type.Elem()), idx}
				continue
			}
			if ptag.name == "-" || ptag.name == "" {
				continue
			}

			// A format tag (see parseFormatStructTag) temporarily overrides
			// the date layout while this field's encoder is built; the
			// previous layout is restored afterwards.
			dateFormat, ok := parseFormatStructTag(field)
			oldFormat := e.dateFormat
			if ok {
				switch dateFormat {
				case "date-time":
					e.dateFormat = time.RFC3339
				case "date":
					e.dateFormat = "2006-01-02"
				}
			}

			var encoderFn encoderFunc
			if ptag.omitzero {
				typeEncoderFn := e.typeEncoder(field.Type)
				encoderFn = func(key string, value reflect.Value, writer *multipart.Writer) error {
					if value.IsZero() {
						return nil
					}
					return typeEncoderFn(key, value, writer)
				}
			} else {
				encoderFn = e.typeEncoder(field.Type)
			}
			encoderFields = append(encoderFields, encoderField{ptag, encoderFn, idx})
			e.dateFormat = oldFormat
		}
	}
	collectEncoderFields(t, []int{})

	// Ensure deterministic output by sorting by lexicographic order.
	sort.Slice(encoderFields, func(i, j int) bool {
		return encoderFields[i].tag.name < encoderFields[j].tag.name
	})

	return func(key string, value reflect.Value, writer *multipart.Writer) error {
		prefix := key
		if prefix != "" {
			prefix += "."
		}

		for _, ef := range encoderFields {
			field := value.FieldByIndex(ef.idx)
			if err := ef.fn(prefix+ef.tag.name, field, writer); err != nil {
				return err
			}
		}

		if extraEncoder != nil {
			// encodeMapEntries appends its own "." separator, so it must get
			// the unprefixed key; passing prefix produced "key..name" for
			// nested structs.
			if err := e.encodeMapEntries(key, value.FieldByIndex(extraEncoder.idx), writer); err != nil {
				return err
			}
		}

		return nil
	}
}
286
// newFieldTypeEncoder builds an encoder for struct types implementing
// internalparam.FieldLike. It reads the fields Present, Null, Raw and Value by
// name:
//   - nothing is written unless Present is true and Null is false;
//   - a non-nil Raw is encoded in place of Value;
//   - otherwise Value is encoded with the encoder for its declared type.
func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc {
	valueField, _ := t.FieldByName("Value")
	valueEncoder := e.typeEncoder(valueField.Type)

	return func(key string, value reflect.Value, writer *multipart.Writer) error {
		if !value.FieldByName("Present").Bool() || value.FieldByName("Null").Bool() {
			return nil
		}
		if raw := value.FieldByName("Raw"); !raw.IsNil() {
			return e.typeEncoder(raw.Type())(key, raw, writer)
		}
		return valueEncoder(key, value.FieldByName("Value"), writer)
	}
}
307
// newTimeTypeEncoder builds an encoder for values convertible to time.Time.
// The layout is captured from e.dateFormat when the encoder is built, so later
// changes to e do not affect it.
func (e *encoder) newTimeTypeEncoder() encoderFunc {
	layout := e.dateFormat
	timeType := reflect.TypeOf(time.Time{})
	return func(key string, value reflect.Value, writer *multipart.Writer) error {
		ts := value.Convert(timeType).Interface().(time.Time)
		return writer.WriteField(key, ts.Format(layout))
	}
}
314
// newInterfaceEncoder builds an encoder for interface-typed values. It
// dispatches on the dynamic type currently held by the interface and writes
// nothing when the interface is nil.
//
// Note the value receiver: the closure captures a copy of the encoder as it
// was when this encoder was built (including dateFormat). Later changes to
// the original encoder are not observed.
// NOTE(review): keep the value receiver unless dropping that snapshot is
// intended.
func (e encoder) newInterfaceEncoder() encoderFunc {
	return func(key string, value reflect.Value, writer *multipart.Writer) error {
		inner := value.Elem()
		if inner.IsValid() {
			return e.typeEncoder(inner.Type())(key, inner, writer)
		}
		return nil
	}
}
324
325var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
326
327func escapeQuotes(s string) string {
328 return quoteEscaper.Replace(s)
329}
330
// newReaderTypeEncoder builds an encoder for values convertible to io.Reader.
// Each value is streamed into its own file part whose form name is key.
//
// Filename, in order of preference:
//   - a Filename() method;
//   - the path.Base of a Name() method (e.g. *os.File);
//   - "anonymous_file".
//
// Content-Type comes from a ContentType() method, otherwise
// "application/octet-stream".
//
// A nil reader (nil interface or nil pointer) writes nothing. This matches
// how the pointer and interface encoders skip nil values; previously it
// panicked.
func (e *encoder) newReaderTypeEncoder() encoderFunc {
	readerType := reflect.TypeOf((*io.Reader)(nil)).Elem()
	return func(key string, value reflect.Value, writer *multipart.Writer) error {
		// Without these guards a nil interface panics in the type assertion
		// below, and a nil pointer typically panics in Name() or Read.
		if !value.IsValid() {
			return nil
		}
		if k := value.Kind(); (k == reflect.Interface || k == reflect.Pointer) && value.IsNil() {
			return nil
		}

		reader := value.Convert(readerType).Interface().(io.Reader)
		filename := "anonymous_file"
		contentType := "application/octet-stream"
		if named, ok := reader.(interface{ Filename() string }); ok {
			filename = named.Filename()
		} else if named, ok := reader.(interface{ Name() string }); ok {
			filename = path.Base(named.Name())
		}
		if typed, ok := reader.(interface{ ContentType() string }); ok {
			contentType = typed.ContentType()
		}

		// Mirrors multipart.Writer.CreateFormFile, but with a caller-chosen
		// Content-Type instead of the fixed application/octet-stream.
		h := make(textproto.MIMEHeader)
		h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, escapeQuotes(key), escapeQuotes(filename)))
		h.Set("Content-Type", contentType)
		filewriter, err := writer.CreatePart(h)
		if err != nil {
			return err
		}
		_, err = io.Copy(filewriter, reader)
		return err
	}
}
357
// encodeMapEntries writes every entry of the map v as a form field named
// "<key>.<mapKey>", or just "<mapKey>" when key is empty.
//
// Entries are written in lexicographic order of their keys so the output is
// deterministic. Each value is encoded with the encoder for the map's element
// type. Only string-kinded keys can become field names; a non-empty map with
// any other key kind yields an error.
func (e *encoder) encodeMapEntries(key string, v reflect.Value, writer *multipart.Writer) error {
	type mapPair struct {
		key   string
		value reflect.Value
	}

	if key != "" {
		key = key + "."
	}

	// The key kind is a property of the map type, so check it once rather
	// than per entry. It is only enforced when there is something to encode.
	if v.Len() > 0 && v.Type().Key().Kind() != reflect.String {
		return fmt.Errorf("cannot encode a map with a non string key")
	}

	pairs := make([]mapPair, 0, v.Len())
	for iter := v.MapRange(); iter.Next(); {
		pairs = append(pairs, mapPair{key: iter.Key().String(), value: iter.Value()})
	}

	// Ensure deterministic output.
	sort.Slice(pairs, func(i, j int) bool {
		return pairs[i].key < pairs[j].key
	})

	elementEncoder := e.typeEncoder(v.Type().Elem())
	for _, p := range pairs {
		if err := elementEncoder(key+p.key, p.value, writer); err != nil {
			return err
		}
	}

	return nil
}
396
// newMapEncoder builds an encoder for map types. Every entry is written by
// encodeMapEntries under "<key>.<mapKey>" (or "<mapKey>" when key is empty).
// The map type t is not needed up front because encodeMapEntries resolves
// the element encoder from the value it is given.
func (e *encoder) newMapEncoder(t reflect.Type) encoderFunc {
	return e.encodeMapEntries
}