codec_map.go

  1// Copyright 2019 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package impl
  6
  7import (
  8	"reflect"
  9	"sort"
 10
 11	"google.golang.org/protobuf/encoding/protowire"
 12	"google.golang.org/protobuf/internal/errors"
 13	"google.golang.org/protobuf/internal/genid"
 14	"google.golang.org/protobuf/reflect/protoreflect"
 15)
 16
 17type mapInfo struct {
 18	goType     reflect.Type
 19	keyWiretag uint64
 20	valWiretag uint64
 21	keyFuncs   valueCoderFuncs
 22	valFuncs   valueCoderFuncs
 23	keyZero    protoreflect.Value
 24	keyKind    protoreflect.Kind
 25	conv       *mapConverter
 26}
 27
 28func encoderFuncsForMap(fd protoreflect.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
 29	// TODO: Consider generating specialized map coders.
 30	keyField := fd.MapKey()
 31	valField := fd.MapValue()
 32	keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()])
 33	valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()])
 34	keyFuncs := encoderFuncsForValue(keyField)
 35	valFuncs := encoderFuncsForValue(valField)
 36	conv := newMapConverter(ft, fd)
 37
 38	mapi := &mapInfo{
 39		goType:     ft,
 40		keyWiretag: keyWiretag,
 41		valWiretag: valWiretag,
 42		keyFuncs:   keyFuncs,
 43		valFuncs:   valFuncs,
 44		keyZero:    keyField.Default(),
 45		keyKind:    keyField.Kind(),
 46		conv:       conv,
 47	}
 48	if valField.Kind() == protoreflect.MessageKind {
 49		valueMessage = getMessageInfo(ft.Elem())
 50	}
 51
 52	funcs = pointerCoderFuncs{
 53		size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
 54			return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
 55		},
 56		marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
 57			return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
 58		},
 59		unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
 60			mp := p.AsValueOf(ft)
 61			if mp.Elem().IsNil() {
 62				mp.Elem().Set(reflect.MakeMap(mapi.goType))
 63			}
 64			if f.mi == nil {
 65				return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
 66			} else {
 67				return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
 68			}
 69		},
 70	}
 71	switch valField.Kind() {
 72	case protoreflect.MessageKind:
 73		funcs.merge = mergeMapOfMessage
 74	case protoreflect.BytesKind:
 75		funcs.merge = mergeMapOfBytes
 76	default:
 77		funcs.merge = mergeMap
 78	}
 79	if valFuncs.isInit != nil {
 80		funcs.isInit = func(p pointer, f *coderFieldInfo) error {
 81			return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
 82		}
 83	}
 84	return valueMessage, funcs
 85}
 86
 87const (
 88	mapKeyTagSize = 1 // field 1, tag size 1.
 89	mapValTagSize = 1 // field 2, tag size 2.
 90)
 91
 92func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
 93	if mapv.Len() == 0 {
 94		return 0
 95	}
 96	n := 0
 97	iter := mapv.MapRange()
 98	for iter.Next() {
 99		key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
100		keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
101		var valSize int
102		value := mapi.conv.valConv.PBValueOf(iter.Value())
103		if f.mi == nil {
104			valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
105		} else {
106			p := pointerOfValue(iter.Value())
107			valSize += mapValTagSize
108			valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts))
109		}
110		n += f.tagsize + protowire.SizeBytes(keySize+valSize)
111	}
112	return n
113}
114
115func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
116	if wtyp != protowire.BytesType {
117		return out, errUnknown
118	}
119	b, n := protowire.ConsumeBytes(b)
120	if n < 0 {
121		return out, errDecode
122	}
123	var (
124		key = mapi.keyZero
125		val = mapi.conv.valConv.New()
126	)
127	for len(b) > 0 {
128		num, wtyp, n := protowire.ConsumeTag(b)
129		if n < 0 {
130			return out, errDecode
131		}
132		if num > protowire.MaxValidNumber {
133			return out, errDecode
134		}
135		b = b[n:]
136		err := errUnknown
137		switch num {
138		case genid.MapEntry_Key_field_number:
139			var v protoreflect.Value
140			var o unmarshalOutput
141			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
142			if err != nil {
143				break
144			}
145			key = v
146			n = o.n
147		case genid.MapEntry_Value_field_number:
148			var v protoreflect.Value
149			var o unmarshalOutput
150			v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
151			if err != nil {
152				break
153			}
154			val = v
155			n = o.n
156		}
157		if err == errUnknown {
158			n = protowire.ConsumeFieldValue(num, wtyp, b)
159			if n < 0 {
160				return out, errDecode
161			}
162		} else if err != nil {
163			return out, err
164		}
165		b = b[n:]
166	}
167	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
168	out.n = n
169	return out, nil
170}
171
172func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
173	if wtyp != protowire.BytesType {
174		return out, errUnknown
175	}
176	b, n := protowire.ConsumeBytes(b)
177	if n < 0 {
178		return out, errDecode
179	}
180	var (
181		key = mapi.keyZero
182		val = reflect.New(f.mi.GoReflectType.Elem())
183	)
184	for len(b) > 0 {
185		num, wtyp, n := protowire.ConsumeTag(b)
186		if n < 0 {
187			return out, errDecode
188		}
189		if num > protowire.MaxValidNumber {
190			return out, errDecode
191		}
192		b = b[n:]
193		err := errUnknown
194		switch num {
195		case 1:
196			var v protoreflect.Value
197			var o unmarshalOutput
198			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
199			if err != nil {
200				break
201			}
202			key = v
203			n = o.n
204		case 2:
205			if wtyp != protowire.BytesType {
206				break
207			}
208			var v []byte
209			v, n = protowire.ConsumeBytes(b)
210			if n < 0 {
211				return out, errDecode
212			}
213			var o unmarshalOutput
214			o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
215			if o.initialized {
216				// Consider this map item initialized so long as we see
217				// an initialized value.
218				out.initialized = true
219			}
220		}
221		if err == errUnknown {
222			n = protowire.ConsumeFieldValue(num, wtyp, b)
223			if n < 0 {
224				return out, errDecode
225			}
226		} else if err != nil {
227			return out, err
228		}
229		b = b[n:]
230	}
231	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
232	out.n = n
233	return out, nil
234}
235
236func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
237	if f.mi == nil {
238		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
239		val := mapi.conv.valConv.PBValueOf(valrv)
240		size := 0
241		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
242		size += mapi.valFuncs.size(val, mapValTagSize, opts)
243		b = protowire.AppendVarint(b, uint64(size))
244		before := len(b)
245		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
246		if err != nil {
247			return nil, err
248		}
249		b, err = mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
250		if measuredSize := len(b) - before; size != measuredSize && err == nil {
251			return nil, errors.MismatchedSizeCalculation(size, measuredSize)
252		}
253		return b, err
254	} else {
255		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
256		val := pointerOfValue(valrv)
257		valSize := f.mi.sizePointer(val, opts)
258		size := 0
259		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
260		size += mapValTagSize + protowire.SizeBytes(valSize)
261		b = protowire.AppendVarint(b, uint64(size))
262		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
263		if err != nil {
264			return nil, err
265		}
266		b = protowire.AppendVarint(b, mapi.valWiretag)
267		b = protowire.AppendVarint(b, uint64(valSize))
268		before := len(b)
269		b, err = f.mi.marshalAppendPointer(b, val, opts)
270		if measuredSize := len(b) - before; valSize != measuredSize && err == nil {
271			return nil, errors.MismatchedSizeCalculation(valSize, measuredSize)
272		}
273		return b, err
274	}
275}
276
277func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
278	if mapv.Len() == 0 {
279		return b, nil
280	}
281	if opts.Deterministic() {
282		return appendMapDeterministic(b, mapv, mapi, f, opts)
283	}
284	iter := mapv.MapRange()
285	for iter.Next() {
286		var err error
287		b = protowire.AppendVarint(b, f.wiretag)
288		b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
289		if err != nil {
290			return b, err
291		}
292	}
293	return b, nil
294}
295
296func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
297	keys := mapv.MapKeys()
298	sort.Slice(keys, func(i, j int) bool {
299		switch keys[i].Kind() {
300		case reflect.Bool:
301			return !keys[i].Bool() && keys[j].Bool()
302		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
303			return keys[i].Int() < keys[j].Int()
304		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
305			return keys[i].Uint() < keys[j].Uint()
306		case reflect.Float32, reflect.Float64:
307			return keys[i].Float() < keys[j].Float()
308		case reflect.String:
309			return keys[i].String() < keys[j].String()
310		default:
311			panic("invalid kind: " + keys[i].Kind().String())
312		}
313	})
314	for _, key := range keys {
315		var err error
316		b = protowire.AppendVarint(b, f.wiretag)
317		b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
318		if err != nil {
319			return b, err
320		}
321	}
322	return b, nil
323}
324
325func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
326	if mi := f.mi; mi != nil {
327		mi.init()
328		if !mi.needsInitCheck {
329			return nil
330		}
331		iter := mapv.MapRange()
332		for iter.Next() {
333			val := pointerOfValue(iter.Value())
334			if err := mi.checkInitializedPointer(val); err != nil {
335				return err
336			}
337		}
338	} else {
339		iter := mapv.MapRange()
340		for iter.Next() {
341			val := mapi.conv.valConv.PBValueOf(iter.Value())
342			if err := mapi.valFuncs.isInit(val); err != nil {
343				return err
344			}
345		}
346	}
347	return nil
348}
349
350func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
351	dstm := dst.AsValueOf(f.ft).Elem()
352	srcm := src.AsValueOf(f.ft).Elem()
353	if srcm.Len() == 0 {
354		return
355	}
356	if dstm.IsNil() {
357		dstm.Set(reflect.MakeMap(f.ft))
358	}
359	iter := srcm.MapRange()
360	for iter.Next() {
361		dstm.SetMapIndex(iter.Key(), iter.Value())
362	}
363}
364
365func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
366	dstm := dst.AsValueOf(f.ft).Elem()
367	srcm := src.AsValueOf(f.ft).Elem()
368	if srcm.Len() == 0 {
369		return
370	}
371	if dstm.IsNil() {
372		dstm.Set(reflect.MakeMap(f.ft))
373	}
374	iter := srcm.MapRange()
375	for iter.Next() {
376		dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...)))
377	}
378}
379
380func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
381	dstm := dst.AsValueOf(f.ft).Elem()
382	srcm := src.AsValueOf(f.ft).Elem()
383	if srcm.Len() == 0 {
384		return
385	}
386	if dstm.IsNil() {
387		dstm.Set(reflect.MakeMap(f.ft))
388	}
389	iter := srcm.MapRange()
390	for iter.Next() {
391		val := reflect.New(f.ft.Elem().Elem())
392		if f.mi != nil {
393			f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts)
394		} else {
395			opts.Merge(asMessage(val), asMessage(iter.Value()))
396		}
397		dstm.SetMapIndex(iter.Key(), val)
398	}
399}