1package apiform
2
3import (
4 "fmt"
5 "io"
6 "mime/multipart"
7 "net/textproto"
8 "path"
9 "reflect"
10 "sort"
11 "strconv"
12 "strings"
13 "sync"
14 "time"
15
16 "github.com/openai/openai-go/packages/param"
17)
18
// encoders caches the encoderFunc built for each encoderEntry, so typeEncoder
// only has to construct an encoder the first time it sees a given entry.
var encoders sync.Map // map[encoderEntry]encoderFunc
20
21func Marshal(value any, writer *multipart.Writer) error {
22 e := &encoder{
23 dateFormat: time.RFC3339,
24 arrayFmt: "brackets",
25 }
26 return e.marshal(value, writer)
27}
28
29func MarshalRoot(value any, writer *multipart.Writer) error {
30 e := &encoder{
31 root: true,
32 dateFormat: time.RFC3339,
33 arrayFmt: "brackets",
34 }
35 return e.marshal(value, writer)
36}
37
38func MarshalWithSettings(value any, writer *multipart.Writer, arrayFormat string) error {
39 e := &encoder{
40 arrayFmt: arrayFormat,
41 dateFormat: time.RFC3339,
42 }
43 return e.marshal(value, writer)
44}
45
// encoder carries the settings that are in effect while encoder functions
// are being built and run.
type encoder struct {
	// arrayFmt selects how array/slice element keys are derived; it is
	// passed to arrayKeyEncoder.
	arrayFmt string
	// dateFormat is the layout handed to time.Time.Format for time values.
	// newStructTypeEncoder overrides it temporarily for fields that carry a
	// format tag.
	dateFormat string
	// root is set by MarshalRoot and is part of the encoder cache key;
	// newTypeEncoder resets it to false once the time/reader checks are done.
	root bool
}
51
// encoderFunc writes value to writer using key as the form field name (or as
// the prefix from which nested field names are derived).
type encoderFunc func(key string, value reflect.Value, writer *multipart.Writer) error
53
// encoderField pairs one struct field with the encoder used to write it.
type encoderField struct {
	// tag is the parsed struct tag of the field; tag.name is the form key.
	tag parsedStructTag
	// fn encodes the field's value.
	fn encoderFunc
	// idx is the index path to the field, as accepted by
	// reflect.Value.FieldByIndex on the enclosing struct.
	idx []int
}
59
60type encoderEntry struct {
61 reflect.Type
62 dateFormat string
63 root bool
64}
65
66func (e *encoder) marshal(value any, writer *multipart.Writer) error {
67 val := reflect.ValueOf(value)
68 if !val.IsValid() {
69 return nil
70 }
71 typ := val.Type()
72 enc := e.typeEncoder(typ)
73 return enc("", val, writer)
74}
75
76func (e *encoder) typeEncoder(t reflect.Type) encoderFunc {
77 entry := encoderEntry{
78 Type: t,
79 dateFormat: e.dateFormat,
80 root: e.root,
81 }
82
83 if fi, ok := encoders.Load(entry); ok {
84 return fi.(encoderFunc)
85 }
86
87 // To deal with recursive types, populate the map with an
88 // indirect func before we build it. This type waits on the
89 // real func (f) to be ready and then calls it. This indirect
90 // func is only used for recursive types.
91 var (
92 wg sync.WaitGroup
93 f encoderFunc
94 )
95 wg.Add(1)
96 fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(key string, v reflect.Value, writer *multipart.Writer) error {
97 wg.Wait()
98 return f(key, v, writer)
99 }))
100 if loaded {
101 return fi.(encoderFunc)
102 }
103
104 // Compute the real encoder and replace the indirect func with it.
105 f = e.newTypeEncoder(t)
106 wg.Done()
107 encoders.Store(entry, f)
108 return f
109}
110
111func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc {
112 if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
113 return e.newTimeTypeEncoder()
114 }
115 if t.Implements(reflect.TypeOf((*io.Reader)(nil)).Elem()) {
116 return e.newReaderTypeEncoder()
117 }
118 e.root = false
119 switch t.Kind() {
120 case reflect.Pointer:
121 inner := t.Elem()
122
123 innerEncoder := e.typeEncoder(inner)
124 return func(key string, v reflect.Value, writer *multipart.Writer) error {
125 if !v.IsValid() || v.IsNil() {
126 return nil
127 }
128 return innerEncoder(key, v.Elem(), writer)
129 }
130 case reflect.Struct:
131 return e.newStructTypeEncoder(t)
132 case reflect.Slice, reflect.Array:
133 return e.newArrayTypeEncoder(t)
134 case reflect.Map:
135 return e.newMapEncoder(t)
136 case reflect.Interface:
137 return e.newInterfaceEncoder()
138 default:
139 return e.newPrimitiveTypeEncoder(t)
140 }
141}
142
143func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc {
144 switch t.Kind() {
145 // Note that we could use `gjson` to encode these types but it would complicate our
146 // code more and this current code shouldn't cause any issues
147 case reflect.String:
148 return func(key string, v reflect.Value, writer *multipart.Writer) error {
149 return writer.WriteField(key, v.String())
150 }
151 case reflect.Bool:
152 return func(key string, v reflect.Value, writer *multipart.Writer) error {
153 if v.Bool() {
154 return writer.WriteField(key, "true")
155 }
156 return writer.WriteField(key, "false")
157 }
158 case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64:
159 return func(key string, v reflect.Value, writer *multipart.Writer) error {
160 return writer.WriteField(key, strconv.FormatInt(v.Int(), 10))
161 }
162 case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64:
163 return func(key string, v reflect.Value, writer *multipart.Writer) error {
164 return writer.WriteField(key, strconv.FormatUint(v.Uint(), 10))
165 }
166 case reflect.Float32:
167 return func(key string, v reflect.Value, writer *multipart.Writer) error {
168 return writer.WriteField(key, strconv.FormatFloat(v.Float(), 'f', -1, 32))
169 }
170 case reflect.Float64:
171 return func(key string, v reflect.Value, writer *multipart.Writer) error {
172 return writer.WriteField(key, strconv.FormatFloat(v.Float(), 'f', -1, 64))
173 }
174 default:
175 return func(key string, v reflect.Value, writer *multipart.Writer) error {
176 return fmt.Errorf("unknown type received at primitive encoder: %s", t.String())
177 }
178 }
179}
180
// arrayKeyEncoder returns the function that derives the form key of the i-th
// element of an array stored under a parent key, for the given format:
//
//	"comma", "repeat"   -> key
//	"brackets"          -> key + "[]"
//	"indices:dots"      -> key + "." + i   (just i when key is empty)
//	"indices:brackets"  -> key + "[" + i + "]"   (just i when key is empty)
//
// It returns nil for any other format.
func arrayKeyEncoder(arrayFmt string) func(string, int) string {
	switch arrayFmt {
	case "brackets":
		return func(parent string, _ int) string {
			return parent + "[]"
		}
	case "comma", "repeat":
		return func(parent string, _ int) string {
			return parent
		}
	case "indices:dots":
		return func(parent string, i int) string {
			pos := strconv.Itoa(i)
			if parent == "" {
				return pos
			}
			return parent + "." + pos
		}
	case "indices:brackets":
		return func(parent string, i int) string {
			pos := strconv.Itoa(i)
			if parent == "" {
				return pos
			}
			return parent + "[" + pos + "]"
		}
	default:
		return nil
	}
}
205
206func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc {
207 itemEncoder := e.typeEncoder(t.Elem())
208 keyFn := arrayKeyEncoder(e.arrayFmt)
209 return func(key string, v reflect.Value, writer *multipart.Writer) error {
210 if keyFn == nil {
211 return fmt.Errorf("apiform: unsupported array format")
212 }
213 for i := 0; i < v.Len(); i++ {
214 err := itemEncoder(keyFn(key, i), v.Index(i), writer)
215 if err != nil {
216 return err
217 }
218 }
219 return nil
220 }
221}
222
223func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
224 if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) {
225 return e.newRichFieldTypeEncoder(t)
226 }
227
228 for i := 0; i < t.NumField(); i++ {
229 if t.Field(i).Type == paramUnionType && t.Field(i).Anonymous {
230 return e.newStructUnionTypeEncoder(t)
231 }
232 }
233
234 encoderFields := []encoderField{}
235 extraEncoder := (*encoderField)(nil)
236
237 // This helper allows us to recursively collect field encoders into a flat
238 // array. The parameter `index` keeps track of the access patterns necessary
239 // to get to some field.
240 var collectEncoderFields func(r reflect.Type, index []int)
241 collectEncoderFields = func(r reflect.Type, index []int) {
242 for i := 0; i < r.NumField(); i++ {
243 idx := append(index, i)
244 field := t.FieldByIndex(idx)
245 if !field.IsExported() {
246 continue
247 }
248 // If this is an embedded struct, traverse one level deeper to extract
249 // the field and get their encoders as well.
250 if field.Anonymous {
251 collectEncoderFields(field.Type, idx)
252 continue
253 }
254 // If json tag is not present, then we skip, which is intentionally
255 // different behavior from the stdlib.
256 ptag, ok := parseFormStructTag(field)
257 if !ok {
258 continue
259 }
260 // We only want to support unexported field if they're tagged with
261 // `extras` because that field shouldn't be part of the public API. We
262 // also want to only keep the top level extras
263 if ptag.extras && len(index) == 0 {
264 extraEncoder = &encoderField{ptag, e.typeEncoder(field.Type.Elem()), idx}
265 continue
266 }
267 if ptag.name == "-" || ptag.name == "" {
268 continue
269 }
270
271 dateFormat, ok := parseFormatStructTag(field)
272 oldFormat := e.dateFormat
273 if ok {
274 switch dateFormat {
275 case "date-time":
276 e.dateFormat = time.RFC3339
277 case "date":
278 e.dateFormat = "2006-01-02"
279 }
280 }
281
282 var encoderFn encoderFunc
283 if ptag.omitzero {
284 typeEncoderFn := e.typeEncoder(field.Type)
285 encoderFn = func(key string, value reflect.Value, writer *multipart.Writer) error {
286 if value.IsZero() {
287 return nil
288 }
289 return typeEncoderFn(key, value, writer)
290 }
291 } else {
292 encoderFn = e.typeEncoder(field.Type)
293 }
294 encoderFields = append(encoderFields, encoderField{ptag, encoderFn, idx})
295 e.dateFormat = oldFormat
296 }
297 }
298 collectEncoderFields(t, []int{})
299
300 // Ensure deterministic output by sorting by lexicographic order
301 sort.Slice(encoderFields, func(i, j int) bool {
302 return encoderFields[i].tag.name < encoderFields[j].tag.name
303 })
304
305 return func(key string, value reflect.Value, writer *multipart.Writer) error {
306 if key != "" {
307 key = key + "."
308 }
309
310 for _, ef := range encoderFields {
311 field := value.FieldByIndex(ef.idx)
312 err := ef.fn(key+ef.tag.name, field, writer)
313 if err != nil {
314 return err
315 }
316 }
317
318 if extraEncoder != nil {
319 err := e.encodeMapEntries(key, value.FieldByIndex(extraEncoder.idx), writer)
320 if err != nil {
321 return err
322 }
323 }
324
325 return nil
326 }
327}
328
// paramUnionType is the reflect.Type of param.APIUnion. Structs that embed it
// are encoded by newStructUnionTypeEncoder.
var paramUnionType = reflect.TypeOf((*param.APIUnion)(nil)).Elem()
330
331func (e *encoder) newStructUnionTypeEncoder(t reflect.Type) encoderFunc {
332 var fieldEncoders []encoderFunc
333 for i := 0; i < t.NumField(); i++ {
334 field := t.Field(i)
335 if field.Type == paramUnionType && field.Anonymous {
336 fieldEncoders = append(fieldEncoders, nil)
337 continue
338 }
339 fieldEncoders = append(fieldEncoders, e.typeEncoder(field.Type))
340 }
341
342 return func(key string, value reflect.Value, writer *multipart.Writer) error {
343 for i := 0; i < t.NumField(); i++ {
344 if value.Field(i).Type() == paramUnionType {
345 continue
346 }
347 if !value.Field(i).IsZero() {
348 return fieldEncoders[i](key, value.Field(i), writer)
349 }
350 }
351 return fmt.Errorf("apiform: union %s has no field set", t.String())
352 }
353}
354
355func (e *encoder) newTimeTypeEncoder() encoderFunc {
356 format := e.dateFormat
357 return func(key string, value reflect.Value, writer *multipart.Writer) error {
358 return writer.WriteField(key, value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format))
359 }
360}
361
362func (e encoder) newInterfaceEncoder() encoderFunc {
363 return func(key string, value reflect.Value, writer *multipart.Writer) error {
364 value = value.Elem()
365 if !value.IsValid() {
366 return nil
367 }
368 return e.typeEncoder(value.Type())(key, value, writer)
369 }
370}
371
// quoteEscaper prefixes every backslash and every double quote with a
// backslash.
var quoteEscaper = strings.NewReplacer(
	`\`, `\\`,
	`"`, `\"`,
)

// escapeQuotes returns s with backslashes and double quotes
// backslash-escaped, for use inside a quoted header parameter.
func escapeQuotes(s string) string {
	escaped := quoteEscaper.Replace(s)
	return escaped
}
377
378func (e *encoder) newReaderTypeEncoder() encoderFunc {
379 return func(key string, value reflect.Value, writer *multipart.Writer) error {
380 reader, ok := value.Convert(reflect.TypeOf((*io.Reader)(nil)).Elem()).Interface().(io.Reader)
381 if !ok {
382 return nil
383 }
384 filename := "anonymous_file"
385 contentType := "application/octet-stream"
386 if named, ok := reader.(interface{ Filename() string }); ok {
387 filename = named.Filename()
388 } else if named, ok := reader.(interface{ Name() string }); ok {
389 filename = path.Base(named.Name())
390 }
391 if typed, ok := reader.(interface{ ContentType() string }); ok {
392 contentType = typed.ContentType()
393 }
394
395 // Below is taken almost 1-for-1 from [multipart.CreateFormFile]
396 h := make(textproto.MIMEHeader)
397 h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, escapeQuotes(key), escapeQuotes(filename)))
398 h.Set("Content-Type", contentType)
399 filewriter, err := writer.CreatePart(h)
400 if err != nil {
401 return err
402 }
403 _, err = io.Copy(filewriter, reader)
404 return err
405 }
406}
407
408// Given a []byte of json (may either be an empty object or an object that already contains entries)
409// encode all of the entries in the map to the json byte array.
410func (e *encoder) encodeMapEntries(key string, v reflect.Value, writer *multipart.Writer) error {
411 type mapPair struct {
412 key string
413 value reflect.Value
414 }
415
416 if key != "" {
417 key = key + "."
418 }
419
420 pairs := []mapPair{}
421
422 iter := v.MapRange()
423 for iter.Next() {
424 if iter.Key().Type().Kind() == reflect.String {
425 pairs = append(pairs, mapPair{key: iter.Key().String(), value: iter.Value()})
426 } else {
427 return fmt.Errorf("cannot encode a map with a non string key")
428 }
429 }
430
431 // Ensure deterministic output
432 sort.Slice(pairs, func(i, j int) bool {
433 return pairs[i].key < pairs[j].key
434 })
435
436 elementEncoder := e.typeEncoder(v.Type().Elem())
437 for _, p := range pairs {
438 err := elementEncoder(key+string(p.key), p.value, writer)
439 if err != nil {
440 return err
441 }
442 }
443
444 return nil
445}
446
447func (e *encoder) newMapEncoder(_ reflect.Type) encoderFunc {
448 return func(key string, value reflect.Value, writer *multipart.Writer) error {
449 return e.encodeMapEntries(key, value, writer)
450 }
451}
452
// WriteExtras writes every string-valued entry of extras to writer as a form
// field named after its key. Fields are written in ascending key order so the
// output is deterministic.
//
// Entries whose value is not a string are skipped; they do not prevent the
// remaining entries from being written. The first error returned by
// writer.WriteField aborts the loop and is returned.
func WriteExtras(writer *multipart.Writer, extras map[string]any) (err error) {
	keys := make([]string, 0, len(extras))
	for k := range extras {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	for _, k := range keys {
		str, ok := extras[k].(string)
		if !ok {
			// Only string extras can be written as plain form fields.
			continue
		}
		if err = writer.WriteField(k, str); err != nil {
			return err
		}
	}
	return nil
}