struct.go

  1// Copyright (c) Microsoft Corporation.
  2// Licensed under the MIT license.
  3
  4package json
  5
  6import (
  7	"encoding/json"
  8	"fmt"
  9	"reflect"
 10	"strings"
 11)
 12
 13func unmarshalStruct(jdec *json.Decoder, i interface{}) error {
 14	v := reflect.ValueOf(i)
 15	if v.Kind() != reflect.Ptr {
 16		return fmt.Errorf("Unmarshal() received type %T, which is not a *struct", i)
 17	}
 18	v = v.Elem()
 19	if v.Kind() != reflect.Struct {
 20		return fmt.Errorf("Unmarshal() received type %T, which is not a *struct", i)
 21	}
 22
 23	if hasUnmarshalJSON(v) {
 24		// Indicates that this type has a custom Unmarshaler.
 25		return jdec.Decode(v.Addr().Interface())
 26	}
 27
 28	f := v.FieldByName(addField)
 29	if f.Kind() == reflect.Invalid {
 30		return fmt.Errorf("Unmarshal(%T) only supports structs that have the field AdditionalFields or implements json.Unmarshaler", i)
 31	}
 32
 33	if f.Kind() != reflect.Map || !f.Type().AssignableTo(mapStrInterType) {
 34		return fmt.Errorf("type %T has field 'AdditionalFields' that is not a map[string]interface{}", i)
 35	}
 36
 37	dec := newDecoder(jdec, v)
 38	return dec.run()
 39}
 40
 41type decoder struct {
 42	dec        *json.Decoder
 43	value      reflect.Value // This will be a reflect.Struct
 44	translator translateFields
 45	key        string
 46}
 47
 48func newDecoder(dec *json.Decoder, value reflect.Value) *decoder {
 49	return &decoder{value: value, dec: dec}
 50}
 51
 52// run runs our decoder state machine.
 53func (d *decoder) run() error {
 54	var state = d.start
 55	var err error
 56	for {
 57		state, err = state()
 58		if err != nil {
 59			return err
 60		}
 61		if state == nil {
 62			return nil
 63		}
 64	}
 65}
 66
 67// start looks for our opening delimeter '{' and then transitions to looping through our fields.
 68func (d *decoder) start() (stateFn, error) {
 69	var err error
 70	d.translator, err = findFields(d.value)
 71	if err != nil {
 72		return nil, err
 73	}
 74
 75	delim, err := d.dec.Token()
 76	if err != nil {
 77		return nil, err
 78	}
 79	if !delimIs(delim, '{') {
 80		return nil, fmt.Errorf("Unmarshal expected opening {, received %v", delim)
 81	}
 82
 83	return d.next, nil
 84}
 85
 86// next gets the next struct field name from the raw json or stops the machine if we get our closing }.
 87func (d *decoder) next() (stateFn, error) {
 88	if !d.dec.More() {
 89		// Remove the closing }.
 90		if _, err := d.dec.Token(); err != nil {
 91			return nil, err
 92		}
 93		return nil, nil
 94	}
 95
 96	key, err := d.dec.Token()
 97	if err != nil {
 98		return nil, err
 99	}
100
101	d.key = key.(string)
102	return d.storeValue, nil
103}
104
105// storeValue takes the next value and stores it our struct. If the field can't be found
106// in the struct, it pushes the operation to storeAdditional().
107func (d *decoder) storeValue() (stateFn, error) {
108	goName := d.translator.goName(d.key)
109	if goName == "" {
110		goName = d.key
111	}
112
113	// We don't have the field in the struct, so it goes in AdditionalFields.
114	f := d.value.FieldByName(goName)
115	if f.Kind() == reflect.Invalid {
116		return d.storeAdditional, nil
117	}
118
119	// Indicates that this type has a custom Unmarshaler.
120	if hasUnmarshalJSON(f) {
121		err := d.dec.Decode(f.Addr().Interface())
122		if err != nil {
123			return nil, err
124		}
125		return d.next, nil
126	}
127
128	t, isPtr, err := fieldBaseType(d.value, goName)
129	if err != nil {
130		return nil, fmt.Errorf("type(%s) had field(%s) %w", d.value.Type().Name(), goName, err)
131	}
132
133	switch t.Kind() {
134	// We need to recursively call ourselves on any *struct or struct.
135	case reflect.Struct:
136		if isPtr {
137			if f.IsNil() {
138				f.Set(reflect.New(t))
139			}
140		} else {
141			f = f.Addr()
142		}
143		if err := unmarshalStruct(d.dec, f.Interface()); err != nil {
144			return nil, err
145		}
146		return d.next, nil
147	case reflect.Map:
148		v := reflect.MakeMap(f.Type())
149		ptr := newValue(f.Type())
150		ptr.Elem().Set(v)
151		if err := unmarshalMap(d.dec, ptr); err != nil {
152			return nil, err
153		}
154		f.Set(ptr.Elem())
155		return d.next, nil
156	case reflect.Slice:
157		v := reflect.MakeSlice(f.Type(), 0, 0)
158		ptr := newValue(f.Type())
159		ptr.Elem().Set(v)
160		if err := unmarshalSlice(d.dec, ptr); err != nil {
161			return nil, err
162		}
163		f.Set(ptr.Elem())
164		return d.next, nil
165	}
166
167	if !isPtr {
168		f = f.Addr()
169	}
170
171	// For values that are pointers, we need them to be non-nil in order
172	// to decode into them.
173	if f.IsNil() {
174		f.Set(reflect.New(t))
175	}
176
177	if err := d.dec.Decode(f.Interface()); err != nil {
178		return nil, err
179	}
180
181	return d.next, nil
182}
183
184// storeAdditional pushes the key/value into our .AdditionalFields map.
185func (d *decoder) storeAdditional() (stateFn, error) {
186	rw := json.RawMessage{}
187	if err := d.dec.Decode(&rw); err != nil {
188		return nil, err
189	}
190	field := d.value.FieldByName(addField)
191	if field.IsNil() {
192		field.Set(reflect.MakeMap(field.Type()))
193	}
194	field.SetMapIndex(reflect.ValueOf(d.key), reflect.ValueOf(rw))
195	return d.next, nil
196}
197
// fieldBaseType looks up fieldName on v's struct type and returns the field's type
// with at most one pointer level removed. isPtr reports whether a pointer was removed.
// A missing field is reported as an internal bug; a pointer-to-pointer field is
// rejected as unsupported (t is nil in both error cases).
func fieldBaseType(v reflect.Value, fieldName string) (t reflect.Type, isPtr bool, err error) {
	field, found := v.Type().FieldByName(fieldName)
	if !found {
		return nil, false, fmt.Errorf("bug: fieldBaseType() lookup of field(%s) on type(%s): do not have field", fieldName, v.Type().Name())
	}

	base := field.Type
	if base.Kind() != reflect.Ptr {
		return base, false, nil
	}

	base = base.Elem()
	if base.Kind() == reflect.Ptr {
		return nil, true, fmt.Errorf("received pointer to pointer type, not supported")
	}
	return base, true, nil
}
213
// translateField pairs a JSON object key with the Go struct field name it maps to.
type translateField struct {
	jsonName, goName string
}

// translateFields is a struct's translation table, searchable in both directions.
type translateFields []translateField

// goName returns the goName of the first entry whose jsonName matches, or "" if no
// entry matches.
// A linear scan is used on purpose: at the size of a typical struct a slice is
// faster than a map, even in tight loops.
func (t translateFields) goName(jsonName string) string {
	for i := range t {
		if t[i].jsonName == jsonName {
			return t[i].goName
		}
	}
	return ""
}

// jsonName returns the jsonName of the first entry whose goName matches, or "" if
// no entry matches.
// A linear scan is used on purpose: at the size of a typical struct a slice is
// faster than a map, even in tight loops.
func (t translateFields) jsonName(goName string) string {
	for i := range t {
		if t[i].goName == goName {
			return t[i].jsonName
		}
	}
	return ""
}
245
// umarshalerType is the reflect.Type of the json.Unmarshaler interface, used by
// findFields to ask whether a field's type implements it. (The spelling of the name
// is kept as-is because other code may refer to it.)
var umarshalerType = reflect.TypeOf(new(json.Unmarshaler)).Elem()
247
248// findFields parses a struct and writes the field tags for lookup. It will return an error
249// if any field has a type of *struct or struct that does not implement json.Marshaler.
250func findFields(v reflect.Value) (translateFields, error) {
251	if v.Kind() == reflect.Ptr {
252		v = v.Elem()
253	}
254	if v.Kind() != reflect.Struct {
255		return nil, fmt.Errorf("findFields received a %s type, expected *struct or struct", v.Type().Name())
256	}
257	tfs := make([]translateField, 0, v.NumField())
258	for i := 0; i < v.NumField(); i++ {
259		tf := translateField{
260			goName:   v.Type().Field(i).Name,
261			jsonName: parseTag(v.Type().Field(i).Tag.Get("json")),
262		}
263		switch tf.jsonName {
264		case "", "-":
265			tf.jsonName = tf.goName
266		}
267		tfs = append(tfs, tf)
268
269		f := v.Field(i)
270		if f.Kind() == reflect.Ptr {
271			f = f.Elem()
272		}
273		if f.Kind() == reflect.Struct {
274			if f.Type().Implements(umarshalerType) {
275				return nil, fmt.Errorf("struct type %q which has field %q which "+
276					"doesn't implement json.Unmarshaler", v.Type().Name(), v.Type().Field(i).Name)
277			}
278		}
279	}
280	return tfs, nil
281}
282
// parseTag returns the name portion of a struct tag value: everything before the
// first comma, or the whole tag when there are no options. tag is a value as
// returned by reflect.StructTag.Get.
func parseTag(tag string) string {
	name := tag
	if comma := strings.IndexByte(tag, ','); comma >= 0 {
		name = tag[:comma]
	}
	return name
}