encoder.go

  1package apiform
  2
  3import (
  4	"fmt"
  5	"io"
  6	"mime/multipart"
  7	"net/textproto"
  8	"path"
  9	"reflect"
 10	"sort"
 11	"strconv"
 12	"strings"
 13	"sync"
 14	"time"
 15
 16	"github.com/openai/openai-go/packages/param"
 17)
 18
 19var encoders sync.Map // map[encoderEntry]encoderFunc
 20
 21func Marshal(value any, writer *multipart.Writer) error {
 22	e := &encoder{
 23		dateFormat: time.RFC3339,
 24		arrayFmt:   "brackets",
 25	}
 26	return e.marshal(value, writer)
 27}
 28
 29func MarshalRoot(value any, writer *multipart.Writer) error {
 30	e := &encoder{
 31		root:       true,
 32		dateFormat: time.RFC3339,
 33		arrayFmt:   "brackets",
 34	}
 35	return e.marshal(value, writer)
 36}
 37
 38func MarshalWithSettings(value any, writer *multipart.Writer, arrayFormat string) error {
 39	e := &encoder{
 40		arrayFmt:   arrayFormat,
 41		dateFormat: time.RFC3339,
 42	}
 43	return e.marshal(value, writer)
 44}
 45
 46type encoder struct {
 47	arrayFmt   string
 48	dateFormat string
 49	root       bool
 50}
 51
 52type encoderFunc func(key string, value reflect.Value, writer *multipart.Writer) error
 53
 54type encoderField struct {
 55	tag parsedStructTag
 56	fn  encoderFunc
 57	idx []int
 58}
 59
 60type encoderEntry struct {
 61	reflect.Type
 62	dateFormat string
 63	root       bool
 64}
 65
 66func (e *encoder) marshal(value any, writer *multipart.Writer) error {
 67	val := reflect.ValueOf(value)
 68	if !val.IsValid() {
 69		return nil
 70	}
 71	typ := val.Type()
 72	enc := e.typeEncoder(typ)
 73	return enc("", val, writer)
 74}
 75
 76func (e *encoder) typeEncoder(t reflect.Type) encoderFunc {
 77	entry := encoderEntry{
 78		Type:       t,
 79		dateFormat: e.dateFormat,
 80		root:       e.root,
 81	}
 82
 83	if fi, ok := encoders.Load(entry); ok {
 84		return fi.(encoderFunc)
 85	}
 86
 87	// To deal with recursive types, populate the map with an
 88	// indirect func before we build it. This type waits on the
 89	// real func (f) to be ready and then calls it. This indirect
 90	// func is only used for recursive types.
 91	var (
 92		wg sync.WaitGroup
 93		f  encoderFunc
 94	)
 95	wg.Add(1)
 96	fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(key string, v reflect.Value, writer *multipart.Writer) error {
 97		wg.Wait()
 98		return f(key, v, writer)
 99	}))
100	if loaded {
101		return fi.(encoderFunc)
102	}
103
104	// Compute the real encoder and replace the indirect func with it.
105	f = e.newTypeEncoder(t)
106	wg.Done()
107	encoders.Store(entry, f)
108	return f
109}
110
111func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc {
112	if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
113		return e.newTimeTypeEncoder()
114	}
115	if t.Implements(reflect.TypeOf((*io.Reader)(nil)).Elem()) {
116		return e.newReaderTypeEncoder()
117	}
118	e.root = false
119	switch t.Kind() {
120	case reflect.Pointer:
121		inner := t.Elem()
122
123		innerEncoder := e.typeEncoder(inner)
124		return func(key string, v reflect.Value, writer *multipart.Writer) error {
125			if !v.IsValid() || v.IsNil() {
126				return nil
127			}
128			return innerEncoder(key, v.Elem(), writer)
129		}
130	case reflect.Struct:
131		return e.newStructTypeEncoder(t)
132	case reflect.Slice, reflect.Array:
133		return e.newArrayTypeEncoder(t)
134	case reflect.Map:
135		return e.newMapEncoder(t)
136	case reflect.Interface:
137		return e.newInterfaceEncoder()
138	default:
139		return e.newPrimitiveTypeEncoder(t)
140	}
141}
142
143func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc {
144	switch t.Kind() {
145	// Note that we could use `gjson` to encode these types but it would complicate our
146	// code more and this current code shouldn't cause any issues
147	case reflect.String:
148		return func(key string, v reflect.Value, writer *multipart.Writer) error {
149			return writer.WriteField(key, v.String())
150		}
151	case reflect.Bool:
152		return func(key string, v reflect.Value, writer *multipart.Writer) error {
153			if v.Bool() {
154				return writer.WriteField(key, "true")
155			}
156			return writer.WriteField(key, "false")
157		}
158	case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64:
159		return func(key string, v reflect.Value, writer *multipart.Writer) error {
160			return writer.WriteField(key, strconv.FormatInt(v.Int(), 10))
161		}
162	case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64:
163		return func(key string, v reflect.Value, writer *multipart.Writer) error {
164			return writer.WriteField(key, strconv.FormatUint(v.Uint(), 10))
165		}
166	case reflect.Float32:
167		return func(key string, v reflect.Value, writer *multipart.Writer) error {
168			return writer.WriteField(key, strconv.FormatFloat(v.Float(), 'f', -1, 32))
169		}
170	case reflect.Float64:
171		return func(key string, v reflect.Value, writer *multipart.Writer) error {
172			return writer.WriteField(key, strconv.FormatFloat(v.Float(), 'f', -1, 64))
173		}
174	default:
175		return func(key string, v reflect.Value, writer *multipart.Writer) error {
176			return fmt.Errorf("unknown type received at primitive encoder: %s", t.String())
177		}
178	}
179}
180
// arrayKeyEncoder returns the function that derives the form key for the i-th
// element of an array stored under key k, according to arrayFmt:
//
//   - "comma" and "repeat": every element reuses k unchanged.
//   - "brackets": every element uses k + "[]".
//   - "indices:dots": k + "." + i, or just i when k is empty.
//   - "indices:brackets": k + "[" + i + "]", or just i when k is empty.
//
// For any other arrayFmt it returns nil.
func arrayKeyEncoder(arrayFmt string) func(string, int) string {
	// indexed builds a key function that wraps the element index in pre/post,
	// falling back to the bare index when the parent key is empty.
	indexed := func(pre, post string) func(string, int) string {
		return func(k string, i int) string {
			n := strconv.Itoa(i)
			if k == "" {
				return n
			}
			return k + pre + n + post
		}
	}

	switch arrayFmt {
	case "comma", "repeat":
		return func(k string, _ int) string { return k }
	case "brackets":
		return func(k string, _ int) string { return k + "[]" }
	case "indices:dots":
		return indexed(".", "")
	case "indices:brackets":
		return indexed("[", "]")
	}
	return nil
}
205
206func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc {
207	itemEncoder := e.typeEncoder(t.Elem())
208	keyFn := arrayKeyEncoder(e.arrayFmt)
209	return func(key string, v reflect.Value, writer *multipart.Writer) error {
210		if keyFn == nil {
211			return fmt.Errorf("apiform: unsupported array format")
212		}
213		for i := 0; i < v.Len(); i++ {
214			err := itemEncoder(keyFn(key, i), v.Index(i), writer)
215			if err != nil {
216				return err
217			}
218		}
219		return nil
220	}
221}
222
223func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
224	if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) {
225		return e.newRichFieldTypeEncoder(t)
226	}
227
228	for i := 0; i < t.NumField(); i++ {
229		if t.Field(i).Type == paramUnionType && t.Field(i).Anonymous {
230			return e.newStructUnionTypeEncoder(t)
231		}
232	}
233
234	encoderFields := []encoderField{}
235	extraEncoder := (*encoderField)(nil)
236
237	// This helper allows us to recursively collect field encoders into a flat
238	// array. The parameter `index` keeps track of the access patterns necessary
239	// to get to some field.
240	var collectEncoderFields func(r reflect.Type, index []int)
241	collectEncoderFields = func(r reflect.Type, index []int) {
242		for i := 0; i < r.NumField(); i++ {
243			idx := append(index, i)
244			field := t.FieldByIndex(idx)
245			if !field.IsExported() {
246				continue
247			}
248			// If this is an embedded struct, traverse one level deeper to extract
249			// the field and get their encoders as well.
250			if field.Anonymous {
251				collectEncoderFields(field.Type, idx)
252				continue
253			}
254			// If json tag is not present, then we skip, which is intentionally
255			// different behavior from the stdlib.
256			ptag, ok := parseFormStructTag(field)
257			if !ok {
258				continue
259			}
260			// We only want to support unexported field if they're tagged with
261			// `extras` because that field shouldn't be part of the public API. We
262			// also want to only keep the top level extras
263			if ptag.extras && len(index) == 0 {
264				extraEncoder = &encoderField{ptag, e.typeEncoder(field.Type.Elem()), idx}
265				continue
266			}
267			if ptag.name == "-" || ptag.name == "" {
268				continue
269			}
270
271			dateFormat, ok := parseFormatStructTag(field)
272			oldFormat := e.dateFormat
273			if ok {
274				switch dateFormat {
275				case "date-time":
276					e.dateFormat = time.RFC3339
277				case "date":
278					e.dateFormat = "2006-01-02"
279				}
280			}
281
282			var encoderFn encoderFunc
283			if ptag.omitzero {
284				typeEncoderFn := e.typeEncoder(field.Type)
285				encoderFn = func(key string, value reflect.Value, writer *multipart.Writer) error {
286					if value.IsZero() {
287						return nil
288					}
289					return typeEncoderFn(key, value, writer)
290				}
291			} else {
292				encoderFn = e.typeEncoder(field.Type)
293			}
294			encoderFields = append(encoderFields, encoderField{ptag, encoderFn, idx})
295			e.dateFormat = oldFormat
296		}
297	}
298	collectEncoderFields(t, []int{})
299
300	// Ensure deterministic output by sorting by lexicographic order
301	sort.Slice(encoderFields, func(i, j int) bool {
302		return encoderFields[i].tag.name < encoderFields[j].tag.name
303	})
304
305	return func(key string, value reflect.Value, writer *multipart.Writer) error {
306		if key != "" {
307			key = key + "."
308		}
309
310		for _, ef := range encoderFields {
311			field := value.FieldByIndex(ef.idx)
312			err := ef.fn(key+ef.tag.name, field, writer)
313			if err != nil {
314				return err
315			}
316		}
317
318		if extraEncoder != nil {
319			err := e.encodeMapEntries(key, value.FieldByIndex(extraEncoder.idx), writer)
320			if err != nil {
321				return err
322			}
323		}
324
325		return nil
326	}
327}
328
329var paramUnionType = reflect.TypeOf((*param.APIUnion)(nil)).Elem()
330
331func (e *encoder) newStructUnionTypeEncoder(t reflect.Type) encoderFunc {
332	var fieldEncoders []encoderFunc
333	for i := 0; i < t.NumField(); i++ {
334		field := t.Field(i)
335		if field.Type == paramUnionType && field.Anonymous {
336			fieldEncoders = append(fieldEncoders, nil)
337			continue
338		}
339		fieldEncoders = append(fieldEncoders, e.typeEncoder(field.Type))
340	}
341
342	return func(key string, value reflect.Value, writer *multipart.Writer) error {
343		for i := 0; i < t.NumField(); i++ {
344			if value.Field(i).Type() == paramUnionType {
345				continue
346			}
347			if !value.Field(i).IsZero() {
348				return fieldEncoders[i](key, value.Field(i), writer)
349			}
350		}
351		return fmt.Errorf("apiform: union %s has no field set", t.String())
352	}
353}
354
355func (e *encoder) newTimeTypeEncoder() encoderFunc {
356	format := e.dateFormat
357	return func(key string, value reflect.Value, writer *multipart.Writer) error {
358		return writer.WriteField(key, value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format))
359	}
360}
361
362func (e encoder) newInterfaceEncoder() encoderFunc {
363	return func(key string, value reflect.Value, writer *multipart.Writer) error {
364		value = value.Elem()
365		if !value.IsValid() {
366			return nil
367		}
368		return e.typeEncoder(value.Type())(key, value, writer)
369	}
370}
371
// quoteEscaper prefixes every backslash and double quote with a backslash.
var quoteEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`)

// escapeQuotes makes s safe to embed inside a double-quoted MIME header
// parameter (such as name="..." or filename="...").
func escapeQuotes(s string) string { return quoteEscaper.Replace(s) }
377
378func (e *encoder) newReaderTypeEncoder() encoderFunc {
379	return func(key string, value reflect.Value, writer *multipart.Writer) error {
380		reader, ok := value.Convert(reflect.TypeOf((*io.Reader)(nil)).Elem()).Interface().(io.Reader)
381		if !ok {
382			return nil
383		}
384		filename := "anonymous_file"
385		contentType := "application/octet-stream"
386		if named, ok := reader.(interface{ Filename() string }); ok {
387			filename = named.Filename()
388		} else if named, ok := reader.(interface{ Name() string }); ok {
389			filename = path.Base(named.Name())
390		}
391		if typed, ok := reader.(interface{ ContentType() string }); ok {
392			contentType = typed.ContentType()
393		}
394
395		// Below is taken almost 1-for-1 from [multipart.CreateFormFile]
396		h := make(textproto.MIMEHeader)
397		h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, escapeQuotes(key), escapeQuotes(filename)))
398		h.Set("Content-Type", contentType)
399		filewriter, err := writer.CreatePart(h)
400		if err != nil {
401			return err
402		}
403		_, err = io.Copy(filewriter, reader)
404		return err
405	}
406}
407
408// Given a []byte of json (may either be an empty object or an object that already contains entries)
409// encode all of the entries in the map to the json byte array.
410func (e *encoder) encodeMapEntries(key string, v reflect.Value, writer *multipart.Writer) error {
411	type mapPair struct {
412		key   string
413		value reflect.Value
414	}
415
416	if key != "" {
417		key = key + "."
418	}
419
420	pairs := []mapPair{}
421
422	iter := v.MapRange()
423	for iter.Next() {
424		if iter.Key().Type().Kind() == reflect.String {
425			pairs = append(pairs, mapPair{key: iter.Key().String(), value: iter.Value()})
426		} else {
427			return fmt.Errorf("cannot encode a map with a non string key")
428		}
429	}
430
431	// Ensure deterministic output
432	sort.Slice(pairs, func(i, j int) bool {
433		return pairs[i].key < pairs[j].key
434	})
435
436	elementEncoder := e.typeEncoder(v.Type().Elem())
437	for _, p := range pairs {
438		err := elementEncoder(key+string(p.key), p.value, writer)
439		if err != nil {
440			return err
441		}
442	}
443
444	return nil
445}
446
447func (e *encoder) newMapEncoder(_ reflect.Type) encoderFunc {
448	return func(key string, value reflect.Value, writer *multipart.Writer) error {
449		return e.encodeMapEntries(key, value, writer)
450	}
451}
452
// WriteExtras writes each string-valued entry of extras to writer as a plain
// form field, in lexicographic key order so the output is deterministic.
// Entries whose value is not a string are skipped. Previously the first such
// entry stopped the loop, silently dropping whichever entries happened to come
// after it in random map order. The first error from writer is returned.
func WriteExtras(writer *multipart.Writer, extras map[string]any) (err error) {
	keys := make([]string, 0, len(extras))
	for k := range extras {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	for _, k := range keys {
		str, ok := extras[k].(string)
		if !ok {
			continue
		}
		if err = writer.WriteField(k, str); err != nil {
			return err
		}
	}
	return nil
}