encoder.go

package apiform

import (
	"fmt"
	"io"
	"mime/multipart"
	"net/textproto"
	"path"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	internalparam "github.com/openai/openai-go/internal/param"
	"github.com/openai/openai-go/packages/param"
)

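// encoders caches computed encoderFuncs so the reflection work for a given
// type only happens once per (type, dateFormat, root) combination.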
var encoders sync.Map // map[encoderEntry]encoderFunc

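// Marshal encodes value as multipart form fields and writes them to writer,
// formatting time.Time values with the RFC 3339 layout by default.
//
// A minimal usage sketch (params stands in for any struct whose fields carry
// the tags parsed by parseFormStructTag):
//
//	body := &bytes.Buffer{}
//	w := multipart.NewWriter(body)
//	if err := apiform.Marshal(params, w); err != nil {
//		// handle the error
//	}
//	w.Close()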
func Marshal(value interface{}, writer *multipart.Writer) error {
	e := &encoder{dateFormat: time.RFC3339}
	return e.marshal(value, writer)
}

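// MarshalRoot is like Marshal, but marks the encoder as encoding the root of
// the form, which is tracked separately in the encoder cache.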
func MarshalRoot(value interface{}, writer *multipart.Writer) error {
	e := &encoder{root: true, dateFormat: time.RFC3339}
	return e.marshal(value, writer)
}

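// encoder carries the per-marshal settings: the layout used for time.Time
// values and whether the value being encoded is the root of the form.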
type encoder struct {
	dateFormat string
	root       bool
}

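// encoderFunc writes a single value to the multipart writer under the given
// form field key.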
type encoderFunc func(key string, value reflect.Value, writer *multipart.Writer) error

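// encoderField pairs a struct field's parsed tag with its encoder and the
// index path needed to reach the field from the root struct.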
type encoderField struct {
	tag parsedStructTag
	fn  encoderFunc
	idx []int
}

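// encoderEntry is the cache key for the encoders map: two uses of the same
// type share an encoder only when the date format and root flag also match.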
type encoderEntry struct {
	reflect.Type
	dateFormat string
	root       bool
}

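// marshal looks up (or builds) the encoder for value's type and invokes it
// with an empty key, letting nested encoders derive their own field names.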
func (e *encoder) marshal(value interface{}, writer *multipart.Writer) error {
	val := reflect.ValueOf(value)
	if !val.IsValid() {
		return nil
	}
	typ := val.Type()
	enc := e.typeEncoder(typ)
	return enc("", val, writer)
}

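// typeEncoder returns the encoderFunc for t, consulting the package-level
// cache first and building a new encoder via newTypeEncoder on a miss.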
func (e *encoder) typeEncoder(t reflect.Type) encoderFunc {
	entry := encoderEntry{
		Type:       t,
		dateFormat: e.dateFormat,
		root:       e.root,
	}

	if fi, ok := encoders.Load(entry); ok {
		return fi.(encoderFunc)
	}

	// To deal with recursive types, populate the map with an indirect func
	// before we build the real one. The indirect func waits for the real
	// func (f) to be ready and then calls it; it is only invoked for
	// recursive types.
	var (
		wg sync.WaitGroup
		f  encoderFunc
	)
	wg.Add(1)
	fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(key string, v reflect.Value, writer *multipart.Writer) error {
		wg.Wait()
		return f(key, v, writer)
	}))
	if loaded {
		return fi.(encoderFunc)
	}

	// Compute the real encoder and replace the indirect func with it.
	f = e.newTypeEncoder(t)
	wg.Done()
	encoders.Store(entry, f)
	return f
}

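// newTypeEncoder builds the encoder for t, special-casing time.Time and
// io.Reader before dispatching on the type's kind.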
func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc {
	if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
		return e.newTimeTypeEncoder()
	}
	if t.ConvertibleTo(reflect.TypeOf((*io.Reader)(nil)).Elem()) {
		return e.newReaderTypeEncoder()
	}
	e.root = false
	switch t.Kind() {
	case reflect.Pointer:
		inner := t.Elem()

		innerEncoder := e.typeEncoder(inner)
		return func(key string, v reflect.Value, writer *multipart.Writer) error {
			if !v.IsValid() || v.IsNil() {
				return nil
			}
			return innerEncoder(key, v.Elem(), writer)
		}
	case reflect.Struct:
		return e.newStructTypeEncoder(t)
	case reflect.Slice, reflect.Array:
		return e.newArrayTypeEncoder(t)
	case reflect.Map:
		return e.newMapEncoder(t)
	case reflect.Interface:
		return e.newInterfaceEncoder()
	default:
		return e.newPrimitiveTypeEncoder(t)
	}
}

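// newPrimitiveTypeEncoder handles scalar kinds, writing each value as a plain
// form field using its string representation.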
func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc {
	switch t.Kind() {
	// Note that we could use `gjson` to encode these types, but that would
	// complicate the code, and the current approach shouldn't cause any issues.
	case reflect.String:
		return func(key string, v reflect.Value, writer *multipart.Writer) error {
			return writer.WriteField(key, v.String())
		}
	case reflect.Bool:
		return func(key string, v reflect.Value, writer *multipart.Writer) error {
			if v.Bool() {
				return writer.WriteField(key, "true")
			}
			return writer.WriteField(key, "false")
		}
	case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64:
		return func(key string, v reflect.Value, writer *multipart.Writer) error {
			return writer.WriteField(key, strconv.FormatInt(v.Int(), 10))
		}
	case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return func(key string, v reflect.Value, writer *multipart.Writer) error {
			return writer.WriteField(key, strconv.FormatUint(v.Uint(), 10))
		}
	case reflect.Float32:
		return func(key string, v reflect.Value, writer *multipart.Writer) error {
			return writer.WriteField(key, strconv.FormatFloat(v.Float(), 'f', -1, 32))
		}
	case reflect.Float64:
		return func(key string, v reflect.Value, writer *multipart.Writer) error {
			return writer.WriteField(key, strconv.FormatFloat(v.Float(), 'f', -1, 64))
		}
	default:
		return func(key string, v reflect.Value, writer *multipart.Writer) error {
			return fmt.Errorf("unknown type received at primitive encoder: %s", t.String())
		}
	}
}

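// newArrayTypeEncoder encodes slices and arrays by writing each element under
// "<key>.<index>", so a field whose key is files is written as files.0,
// files.1, and so on.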
func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc {
	itemEncoder := e.typeEncoder(t.Elem())

	return func(key string, v reflect.Value, writer *multipart.Writer) error {
		if key != "" {
			key = key + "."
		}
		for i := 0; i < v.Len(); i++ {
			err := itemEncoder(key+strconv.Itoa(i), v.Index(i), writer)
			if err != nil {
				return err
			}
		}
		return nil
	}
}

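// newStructTypeEncoder builds an encoder for a struct type. FieldLike wrappers
// and types in param.OptionalPrimitiveTypes get dedicated encoders; ordinary
// structs have their tagged fields collected, sorted by name, and written as
// "<key>.<field>" entries, with any top-level `extras` map encoded afterwards.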
func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
	if t.Implements(reflect.TypeOf((*internalparam.FieldLike)(nil)).Elem()) {
		return e.newFieldTypeEncoder(t)
	}

	if idx, ok := param.OptionalPrimitiveTypes[t]; ok {
		return e.newRichFieldTypeEncoder(t, idx)
	}

	encoderFields := []encoderField{}
	extraEncoder := (*encoderField)(nil)

	// This helper recursively collects field encoders into a flat slice. The
	// parameter `index` tracks the index path needed to reach a given field
	// from the root struct.
	var collectEncoderFields func(r reflect.Type, index []int)
	collectEncoderFields = func(r reflect.Type, index []int) {
		for i := 0; i < r.NumField(); i++ {
			// Copy the index path so recursive calls cannot clobber stored indices.
			idx := append(append(make([]int, 0, len(index)+1), index...), i)
			field := t.FieldByIndex(idx)
			if !field.IsExported() {
				continue
			}
			// If this is an embedded struct, traverse one level deeper to
			// extract its fields and collect their encoders as well.
			if field.Anonymous {
				collectEncoderFields(field.Type, idx)
				continue
			}
			// If the json tag is not present, we skip the field, which is
			// intentionally different behavior from the stdlib.
			ptag, ok := parseFormStructTag(field)
			if !ok {
				continue
			}
			// We only want to support unexported fields when they're tagged
			// with `extras`, because those fields shouldn't be part of the
			// public API. We also only keep the top-level extras.
			if ptag.extras && len(index) == 0 {
				extraEncoder = &encoderField{ptag, e.typeEncoder(field.Type.Elem()), idx}
				continue
			}
			if ptag.name == "-" || ptag.name == "" {
				continue
			}

			dateFormat, ok := parseFormatStructTag(field)
			oldFormat := e.dateFormat
			if ok {
				switch dateFormat {
				case "date-time":
					e.dateFormat = time.RFC3339
				case "date":
					e.dateFormat = "2006-01-02"
				}
			}

			var encoderFn encoderFunc
			if ptag.omitzero {
				typeEncoderFn := e.typeEncoder(field.Type)
				encoderFn = func(key string, value reflect.Value, writer *multipart.Writer) error {
					if value.IsZero() {
						return nil
					}
					return typeEncoderFn(key, value, writer)
				}
			} else {
				encoderFn = e.typeEncoder(field.Type)
			}
			encoderFields = append(encoderFields, encoderField{ptag, encoderFn, idx})
			e.dateFormat = oldFormat
		}
	}
	collectEncoderFields(t, []int{})

	// Ensure deterministic output by sorting the fields by name in lexicographic order
	sort.Slice(encoderFields, func(i, j int) bool {
		return encoderFields[i].tag.name < encoderFields[j].tag.name
	})

	return func(key string, value reflect.Value, writer *multipart.Writer) error {
		if key != "" {
			key = key + "."
		}

		for _, ef := range encoderFields {
			field := value.FieldByIndex(ef.idx)
			err := ef.fn(key+ef.tag.name, field, writer)
			if err != nil {
				return err
			}
		}

		if extraEncoder != nil {
			err := e.encodeMapEntries(key, value.FieldByIndex(extraEncoder.idx), writer)
			if err != nil {
				return err
			}
		}

		return nil
	}
}

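// newFieldTypeEncoder handles FieldLike wrapper structs: fields that are not
// present or are explicitly null are skipped, a non-nil Raw value is encoded
// with the encoder for its own type, and otherwise the wrapped Value is used.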
func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc {
	f, _ := t.FieldByName("Value")
	enc := e.typeEncoder(f.Type)

	return func(key string, value reflect.Value, writer *multipart.Writer) error {
		present := value.FieldByName("Present")
		if !present.Bool() {
			return nil
		}
		null := value.FieldByName("Null")
		if null.Bool() {
			return nil
		}
		raw := value.FieldByName("Raw")
		if !raw.IsNil() {
			return e.typeEncoder(raw.Type())(key, raw, writer)
		}
		return enc(key, value.FieldByName("Value"), writer)
	}
}

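// newTimeTypeEncoder writes time values as a single form field, using the date
// format that was in effect when the encoder was built.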
func (e *encoder) newTimeTypeEncoder() encoderFunc {
	format := e.dateFormat
	return func(key string, value reflect.Value, writer *multipart.Writer) error {
		return writer.WriteField(key, value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format))
	}
}

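// newInterfaceEncoder unwraps interface values and dispatches to the encoder
// for the dynamic type; invalid (nil) interface values are skipped.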
func (e encoder) newInterfaceEncoder() encoderFunc {
	return func(key string, value reflect.Value, writer *multipart.Writer) error {
		value = value.Elem()
		if !value.IsValid() {
			return nil
		}
		return e.typeEncoder(value.Type())(key, value, writer)
	}
}

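// quoteEscaper escapes backslashes and double quotes so that field names and
// filenames can be safely embedded in a Content-Disposition header.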
var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")

func escapeQuotes(s string) string {
	return quoteEscaper.Replace(s)
}

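// newReaderTypeEncoder encodes io.Reader values as file parts. The filename
// defaults to "anonymous_file" and the content type to
// "application/octet-stream"; both are overridden when the reader exposes
// Filename()/Name() or ContentType() methods.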
func (e *encoder) newReaderTypeEncoder() encoderFunc {
	return func(key string, value reflect.Value, writer *multipart.Writer) error {
		reader := value.Convert(reflect.TypeOf((*io.Reader)(nil)).Elem()).Interface().(io.Reader)
		filename := "anonymous_file"
		contentType := "application/octet-stream"
		if named, ok := reader.(interface{ Filename() string }); ok {
			filename = named.Filename()
		} else if named, ok := reader.(interface{ Name() string }); ok {
			filename = path.Base(named.Name())
		}
		if typed, ok := reader.(interface{ ContentType() string }); ok {
			contentType = typed.ContentType()
		}

		// Below is taken almost 1-for-1 from [multipart.CreateFormFile]
		h := make(textproto.MIMEHeader)
		h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, escapeQuotes(key), escapeQuotes(filename)))
		h.Set("Content-Type", contentType)
		filewriter, err := writer.CreatePart(h)
		if err != nil {
			return err
		}
		_, err = io.Copy(filewriter, reader)
		return err
	}
}

// encodeMapEntries writes each entry of the map value v as its own form field,
// prefixing each map key with the parent key (joined with a dot) and sorting
// the keys for deterministic output. Only maps with string keys are supported.
func (e *encoder) encodeMapEntries(key string, v reflect.Value, writer *multipart.Writer) error {
	type mapPair struct {
		key   string
		value reflect.Value
	}

	if key != "" {
		key = key + "."
	}

	pairs := []mapPair{}

	iter := v.MapRange()
	for iter.Next() {
		if iter.Key().Type().Kind() == reflect.String {
			pairs = append(pairs, mapPair{key: iter.Key().String(), value: iter.Value()})
		} else {
			return fmt.Errorf("cannot encode a map with a non string key")
		}
	}

	// Ensure deterministic output
	sort.Slice(pairs, func(i, j int) bool {
		return pairs[i].key < pairs[j].key
	})

	elementEncoder := e.typeEncoder(v.Type().Elem())
	for _, p := range pairs {
		err := elementEncoder(key+p.key, p.value, writer)
		if err != nil {
			return err
		}
	}

	return nil
}

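// newMapEncoder defers to encodeMapEntries, writing one form field per map
// entry under "<key>.<mapKey>".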
func (e *encoder) newMapEncoder(t reflect.Type) encoderFunc {
	return func(key string, value reflect.Value, writer *multipart.Writer) error {
		return e.encodeMapEntries(key, value, writer)
	}
}