1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT license.
3
4package json
5
6import (
7 "encoding/json"
8 "fmt"
9 "reflect"
10 "strings"
11)
12
13func unmarshalStruct(jdec *json.Decoder, i interface{}) error {
14 v := reflect.ValueOf(i)
15 if v.Kind() != reflect.Ptr {
16 return fmt.Errorf("Unmarshal() received type %T, which is not a *struct", i)
17 }
18 v = v.Elem()
19 if v.Kind() != reflect.Struct {
20 return fmt.Errorf("Unmarshal() received type %T, which is not a *struct", i)
21 }
22
23 if hasUnmarshalJSON(v) {
24 // Indicates that this type has a custom Unmarshaler.
25 return jdec.Decode(v.Addr().Interface())
26 }
27
28 f := v.FieldByName(addField)
29 if f.Kind() == reflect.Invalid {
30 return fmt.Errorf("Unmarshal(%T) only supports structs that have the field AdditionalFields or implements json.Unmarshaler", i)
31 }
32
33 if f.Kind() != reflect.Map || !f.Type().AssignableTo(mapStrInterType) {
34 return fmt.Errorf("type %T has field 'AdditionalFields' that is not a map[string]interface{}", i)
35 }
36
37 dec := newDecoder(jdec, v)
38 return dec.run()
39}
40
// decoder holds the state for decoding a single JSON object into a single struct.
// Its methods (start, next, storeValue, storeAdditional) are the states of the
// machine driven by run().
type decoder struct {
	// dec is the stream that tokens and values are read from.
	dec *json.Decoder
	// value is the struct whose fields are set while decoding.
	value reflect.Value // This will be a reflect.Struct
	// translator maps JSON keys to Go field names; start() fills it via findFields.
	translator translateFields
	// key is the most recent object key read by next(); storeValue and
	// storeAdditional use it to decide where the following value goes.
	key string
}
47
48func newDecoder(dec *json.Decoder, value reflect.Value) *decoder {
49 return &decoder{value: value, dec: dec}
50}
51
52// run runs our decoder state machine.
53func (d *decoder) run() error {
54 var state = d.start
55 var err error
56 for {
57 state, err = state()
58 if err != nil {
59 return err
60 }
61 if state == nil {
62 return nil
63 }
64 }
65}
66
67// start looks for our opening delimeter '{' and then transitions to looping through our fields.
68func (d *decoder) start() (stateFn, error) {
69 var err error
70 d.translator, err = findFields(d.value)
71 if err != nil {
72 return nil, err
73 }
74
75 delim, err := d.dec.Token()
76 if err != nil {
77 return nil, err
78 }
79 if !delimIs(delim, '{') {
80 return nil, fmt.Errorf("Unmarshal expected opening {, received %v", delim)
81 }
82
83 return d.next, nil
84}
85
86// next gets the next struct field name from the raw json or stops the machine if we get our closing }.
87func (d *decoder) next() (stateFn, error) {
88 if !d.dec.More() {
89 // Remove the closing }.
90 if _, err := d.dec.Token(); err != nil {
91 return nil, err
92 }
93 return nil, nil
94 }
95
96 key, err := d.dec.Token()
97 if err != nil {
98 return nil, err
99 }
100
101 d.key = key.(string)
102 return d.storeValue, nil
103}
104
105// storeValue takes the next value and stores it our struct. If the field can't be found
106// in the struct, it pushes the operation to storeAdditional().
107func (d *decoder) storeValue() (stateFn, error) {
108 goName := d.translator.goName(d.key)
109 if goName == "" {
110 goName = d.key
111 }
112
113 // We don't have the field in the struct, so it goes in AdditionalFields.
114 f := d.value.FieldByName(goName)
115 if f.Kind() == reflect.Invalid {
116 return d.storeAdditional, nil
117 }
118
119 // Indicates that this type has a custom Unmarshaler.
120 if hasUnmarshalJSON(f) {
121 err := d.dec.Decode(f.Addr().Interface())
122 if err != nil {
123 return nil, err
124 }
125 return d.next, nil
126 }
127
128 t, isPtr, err := fieldBaseType(d.value, goName)
129 if err != nil {
130 return nil, fmt.Errorf("type(%s) had field(%s) %w", d.value.Type().Name(), goName, err)
131 }
132
133 switch t.Kind() {
134 // We need to recursively call ourselves on any *struct or struct.
135 case reflect.Struct:
136 if isPtr {
137 if f.IsNil() {
138 f.Set(reflect.New(t))
139 }
140 } else {
141 f = f.Addr()
142 }
143 if err := unmarshalStruct(d.dec, f.Interface()); err != nil {
144 return nil, err
145 }
146 return d.next, nil
147 case reflect.Map:
148 v := reflect.MakeMap(f.Type())
149 ptr := newValue(f.Type())
150 ptr.Elem().Set(v)
151 if err := unmarshalMap(d.dec, ptr); err != nil {
152 return nil, err
153 }
154 f.Set(ptr.Elem())
155 return d.next, nil
156 case reflect.Slice:
157 v := reflect.MakeSlice(f.Type(), 0, 0)
158 ptr := newValue(f.Type())
159 ptr.Elem().Set(v)
160 if err := unmarshalSlice(d.dec, ptr); err != nil {
161 return nil, err
162 }
163 f.Set(ptr.Elem())
164 return d.next, nil
165 }
166
167 if !isPtr {
168 f = f.Addr()
169 }
170
171 // For values that are pointers, we need them to be non-nil in order
172 // to decode into them.
173 if f.IsNil() {
174 f.Set(reflect.New(t))
175 }
176
177 if err := d.dec.Decode(f.Interface()); err != nil {
178 return nil, err
179 }
180
181 return d.next, nil
182}
183
184// storeAdditional pushes the key/value into our .AdditionalFields map.
185func (d *decoder) storeAdditional() (stateFn, error) {
186 rw := json.RawMessage{}
187 if err := d.dec.Decode(&rw); err != nil {
188 return nil, err
189 }
190 field := d.value.FieldByName(addField)
191 if field.IsNil() {
192 field.Set(reflect.MakeMap(field.Type()))
193 }
194 field.SetMapIndex(reflect.ValueOf(d.key), reflect.ValueOf(rw))
195 return d.next, nil
196}
197
// fieldBaseType looks up fieldName on the struct type of v. It returns the field's
// type with one level of pointer removed, plus whether that pointer was present.
// A missing field is reported as a bug, and a pointer-to-pointer field type is
// rejected.
func fieldBaseType(v reflect.Value, fieldName string) (t reflect.Type, isPtr bool, err error) {
	typ := v.Type()
	sf, found := typ.FieldByName(fieldName)
	if !found {
		return nil, false, fmt.Errorf("bug: fieldBaseType() lookup of field(%s) on type(%s): do not have field", fieldName, typ.Name())
	}

	base := sf.Type
	isPtr = base.Kind() == reflect.Ptr
	if isPtr {
		base = base.Elem()
		// Only a single level of indirection is supported.
		if base.Kind() == reflect.Ptr {
			return nil, true, fmt.Errorf("received pointer to pointer type, not supported")
		}
	}
	return base, isPtr, nil
}
213
// translateField pairs a JSON object key with the name of the Go struct field it
// maps to.
type translateField struct {
	jsonName, goName string
}

// translateFields is a list of translateField entries with lookups in both
// directions.
type translateFields []translateField

// goName returns the Go field name of the first entry whose JSON name equals
// jsonName, or "" when there is none.
// Note: a linear scan rather than a map; the lists are per-struct and small.
func (t translateFields) goName(jsonName string) string {
	for i := range t {
		if t[i].jsonName == jsonName {
			return t[i].goName
		}
	}
	return ""
}

// jsonName returns the JSON name of the first entry whose Go field name equals
// goName, or "" when there is none.
// Note: a linear scan rather than a map; the lists are per-struct and small.
func (t translateFields) jsonName(goName string) string {
	for i := range t {
		if t[i].goName == goName {
			return t[i].jsonName
		}
	}
	return ""
}
245
// umarshalerType is the reflect.Type of the json.Unmarshaler interface, used for
// Implements checks. It is obtained through a pointer because TypeOf on a nil
// interface value yields nil.
var umarshalerType = reflect.TypeOf(new(json.Unmarshaler)).Elem()
247
248// findFields parses a struct and writes the field tags for lookup. It will return an error
249// if any field has a type of *struct or struct that does not implement json.Marshaler.
250func findFields(v reflect.Value) (translateFields, error) {
251 if v.Kind() == reflect.Ptr {
252 v = v.Elem()
253 }
254 if v.Kind() != reflect.Struct {
255 return nil, fmt.Errorf("findFields received a %s type, expected *struct or struct", v.Type().Name())
256 }
257 tfs := make([]translateField, 0, v.NumField())
258 for i := 0; i < v.NumField(); i++ {
259 tf := translateField{
260 goName: v.Type().Field(i).Name,
261 jsonName: parseTag(v.Type().Field(i).Tag.Get("json")),
262 }
263 switch tf.jsonName {
264 case "", "-":
265 tf.jsonName = tf.goName
266 }
267 tfs = append(tfs, tf)
268
269 f := v.Field(i)
270 if f.Kind() == reflect.Ptr {
271 f = f.Elem()
272 }
273 if f.Kind() == reflect.Struct {
274 if f.Type().Implements(umarshalerType) {
275 return nil, fmt.Errorf("struct type %q which has field %q which "+
276 "doesn't implement json.Unmarshaler", v.Type().Name(), v.Type().Field(i).Name)
277 }
278 }
279 }
280 return tfs, nil
281}
282
// parseTag returns the name portion of a struct tag value, i.e. everything before the
// first comma (the whole string when there is no comma). tag is the string returned
// by reflect.StructTag.Get.
func parseTag(tag string) string {
	if comma := strings.IndexByte(tag, ','); comma >= 0 {
		tag = tag[:comma]
	}
	return tag
}