decode.go

  1// Copyright 2019 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package impl
  6
  7import (
  8	"math/bits"
  9
 10	"google.golang.org/protobuf/encoding/protowire"
 11	"google.golang.org/protobuf/internal/errors"
 12	"google.golang.org/protobuf/internal/flags"
 13	"google.golang.org/protobuf/proto"
 14	"google.golang.org/protobuf/reflect/protoreflect"
 15	"google.golang.org/protobuf/reflect/protoregistry"
 16	"google.golang.org/protobuf/runtime/protoiface"
 17)
 18
 19var errDecode = errors.New("cannot parse invalid wire-format data")
 20var errRecursionDepth = errors.New("exceeded maximum recursion depth")
 21
 22type unmarshalOptions struct {
 23	flags    protoiface.UnmarshalInputFlags
 24	resolver interface {
 25		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
 26		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
 27	}
 28	depth int
 29}
 30
 31func (o unmarshalOptions) Options() proto.UnmarshalOptions {
 32	return proto.UnmarshalOptions{
 33		Merge:          true,
 34		AllowPartial:   true,
 35		DiscardUnknown: o.DiscardUnknown(),
 36		Resolver:       o.resolver,
 37
 38		NoLazyDecoding: o.NoLazyDecoding(),
 39	}
 40}
 41
 42func (o unmarshalOptions) DiscardUnknown() bool {
 43	return o.flags&protoiface.UnmarshalDiscardUnknown != 0
 44}
 45
 46func (o unmarshalOptions) AliasBuffer() bool { return o.flags&protoiface.UnmarshalAliasBuffer != 0 }
 47func (o unmarshalOptions) Validated() bool   { return o.flags&protoiface.UnmarshalValidated != 0 }
 48func (o unmarshalOptions) NoLazyDecoding() bool {
 49	return o.flags&protoiface.UnmarshalNoLazyDecoding != 0
 50}
 51
 52func (o unmarshalOptions) CanBeLazy() bool {
 53	if o.resolver != protoregistry.GlobalTypes {
 54		return false
 55	}
 56	// We ignore the UnmarshalInvalidateSizeCache even though it's not in the default set
 57	return (o.flags & ^(protoiface.UnmarshalAliasBuffer | protoiface.UnmarshalValidated | protoiface.UnmarshalCheckRequired)) == 0
 58}
 59
 60var lazyUnmarshalOptions = unmarshalOptions{
 61	resolver: protoregistry.GlobalTypes,
 62
 63	flags: protoiface.UnmarshalAliasBuffer | protoiface.UnmarshalValidated,
 64
 65	depth: protowire.DefaultRecursionLimit,
 66}
 67
// unmarshalOutput is the internal result of the fast-path decode and
// validation helpers in this package.
type unmarshalOutput struct {
	n           int  // number of bytes consumed
	initialized bool // when true, unmarshal reports protoiface.UnmarshalInitialized to its caller
}
 72
 73// unmarshal is protoreflect.Methods.Unmarshal.
 74func (mi *MessageInfo) unmarshal(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
 75	var p pointer
 76	if ms, ok := in.Message.(*messageState); ok {
 77		p = ms.pointer()
 78	} else {
 79		p = in.Message.(*messageReflectWrapper).pointer()
 80	}
 81	out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
 82		flags:    in.Flags,
 83		resolver: in.Resolver,
 84		depth:    in.Depth,
 85	})
 86	var flags protoiface.UnmarshalOutputFlags
 87	if out.initialized {
 88		flags |= protoiface.UnmarshalInitialized
 89	}
 90	return protoiface.UnmarshalOutput{
 91		Flags: flags,
 92	}, err
 93}
 94
 95// errUnknown is returned during unmarshaling to indicate a parse error that
 96// should result in a field being placed in the unknown fields section (for example,
 97// when the wire type doesn't match) as opposed to the entire unmarshal operation
 98// failing (for example, when a field extends past the available input).
 99//
100// This is a sentinel error which should never be visible to the user.
101var errUnknown = errors.New("unknown")
102
103func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
104	mi.init()
105	opts.depth--
106	if opts.depth < 0 {
107		return out, errRecursionDepth
108	}
109	if flags.ProtoLegacy && mi.isMessageSet {
110		return unmarshalMessageSet(mi, b, p, opts)
111	}
112
113	lazyDecoding := LazyEnabled() // default
114	if opts.NoLazyDecoding() {
115		lazyDecoding = false // explicitly disabled
116	}
117	if mi.lazyOffset.IsValid() && lazyDecoding {
118		return mi.unmarshalPointerLazy(b, p, groupTag, opts)
119	}
120	return mi.unmarshalPointerEager(b, p, groupTag, opts)
121}
122
// unmarshalPointerEager is the message unmarshalling function for all messages that are not lazy.
// The corresponding function for Lazy is in google_lazy.go.
//
// It decodes every field in b into the message at p immediately.
//
// When groupTag is non-zero, b is the body of a group. Decoding stops at the
// first end-group tag, which must carry groupTag as its field number. Running
// out of input before that tag is errDecode.
//
// out.n reports the bytes consumed, including a terminating end-group tag.
// out.initialized is true when two conditions hold: no decoded field reported
// itself uninitialized, and the required-field count check passed.
func (mi *MessageInfo) unmarshalPointerEager(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {

	initialized := true
	// requiredMask ORs together f.validation.requiredBit for every known
	// field decoded successfully; it is checked against numRequiredFields below.
	var requiredMask uint64
	// exts is fetched (and its map allocated if nil) only when the first field
	// that is not among the known coder fields is seen.
	var exts *map[int32]ExtensionField

	// Presence bits are tracked only for messages that have a presence offset.
	var presence presence
	if mi.presenceOffset.IsValid() {
		presence = p.Apply(mi.presenceOffset).PresenceInfo()
	}

	start := len(b)
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		// One- and two-byte varint tags are decoded inline. Longer or
		// malformed tags go through protowire.ConsumeVarint.
		var tag uint64
		if b[0] < 0x80 {
			tag = uint64(b[0])
			b = b[1:]
		} else if len(b) >= 2 && b[1] < 128 {
			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
			b = b[2:]
		} else {
			var n int
			tag, n = protowire.ConsumeVarint(b)
			if n < 0 {
				return out, errDecode
			}
			b = b[n:]
		}
		// The field number lives in the upper bits of the tag and must be in
		// the valid protobuf range.
		var num protowire.Number
		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
			return out, errDecode
		} else {
			num = protowire.Number(n)
		}
		wtyp := protowire.Type(tag & 7)

		if wtyp == protowire.EndGroupType {
			// An end-group tag is only legal when it closes the group being
			// decoded. Clearing groupTag records that the group was properly
			// terminated for the check after the loop. This break exits the
			// for loop, not just the if.
			if num != groupTag {
				return out, errDecode
			}
			groupTag = 0
			break
		}

		// Field lookup: a dense slice for small field numbers, a map otherwise.
		var f *coderFieldInfo
		if int(num) < len(mi.denseCoderFields) {
			f = mi.denseCoderFields[num]
		} else {
			f = mi.coderFields[num]
		}
		var n int
		// err starts as errUnknown. If no branch below replaces it, the field is
		// kept as an unknown field. Each `break` below exits the switch only.
		err := errUnknown
		switch {
		case f != nil:
			// A known field without an unmarshal function is treated as unknown.
			if f.funcs.unmarshal == nil {
				break
			}
			var o unmarshalOutput
			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
			n = o.n
			if err != nil {
				break
			}
			requiredMask |= f.validation.requiredBit
			if f.funcs.isInit != nil && !o.initialized {
				initialized = false
			}

			if f.presenceIndex != noPresence {
				presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
			}

		default:
			// Possible extension.
			if exts == nil && mi.extensionOffset.IsValid() {
				exts = p.Apply(mi.extensionOffset).Extensions()
				if *exts == nil {
					*exts = make(map[int32]ExtensionField)
				}
			}
			if exts == nil {
				// The message has no extension storage; keep as unknown.
				break
			}
			var o unmarshalOutput
			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
			if err != nil {
				break
			}
			n = o.n
			if !o.initialized {
				initialized = false
			}
		}
		if err != nil {
			// Anything other than errUnknown aborts the whole decode.
			if err != errUnknown {
				return out, err
			}
			// Skip over the unknown field's value. n is recomputed here, so any
			// value set by a failed decoder above is discarded.
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, errDecode
			}
			// Preserve the raw tag and value unless unknown fields are being
			// discarded or the message has nowhere to store them.
			if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
				u := mi.mutableUnknownBytes(p)
				*u = protowire.AppendTag(*u, num, wtyp)
				*u = append(*u, b[:n]...)
			}
		}
		b = b[n:]
	}
	// Input ran out while still inside a group: the end-group tag is missing.
	if groupTag != 0 {
		return out, errDecode
	}
	// NOTE(review): comparing the popcount with numRequiredFields assumes each
	// required field owns a distinct requiredBit. Confirm where requiredBit is
	// assigned.
	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
		initialized = false
	}
	if initialized {
		out.initialized = true
	}
	out.n = start - len(b)
	return out, nil
}
247
248func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
249	x := exts[int32(num)]
250	xt := x.Type()
251	if xt == nil {
252		var err error
253		xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
254		if err != nil {
255			if err == protoregistry.NotFound {
256				return out, errUnknown
257			}
258			return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
259		}
260	}
261	xi := getExtensionFieldInfo(xt)
262	if xi.funcs.unmarshal == nil {
263		return out, errUnknown
264	}
265	if flags.LazyUnmarshalExtensions {
266		if opts.CanBeLazy() && x.canLazy(xt) {
267			out, valid := skipExtension(b, xi, num, wtyp, opts)
268			switch valid {
269			case ValidationValid:
270				if out.initialized {
271					x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
272					exts[int32(num)] = x
273					return out, nil
274				}
275			case ValidationInvalid:
276				return out, errDecode
277			case ValidationUnknown:
278			}
279		}
280	}
281	ival := x.Value()
282	if !ival.IsValid() && xi.unmarshalNeedsValue {
283		// Create a new message, list, or map value to fill in.
284		// For enums, create a prototype value to let the unmarshal func know the
285		// concrete type.
286		ival = xt.New()
287	}
288	v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
289	if err != nil {
290		return out, err
291	}
292	if xi.funcs.isInit == nil {
293		out.initialized = true
294	}
295	x.Set(xt, v)
296	exts[int32(num)] = x
297	return out, nil
298}
299
300func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
301	if xi.validation.mi == nil {
302		return out, ValidationUnknown
303	}
304	xi.validation.mi.init()
305	switch xi.validation.typ {
306	case validationTypeMessage:
307		if wtyp != protowire.BytesType {
308			return out, ValidationUnknown
309		}
310		v, n := protowire.ConsumeBytes(b)
311		if n < 0 {
312			return out, ValidationUnknown
313		}
314
315		if opts.Validated() {
316			out.initialized = true
317			out.n = n
318			return out, ValidationValid
319		}
320
321		out, st := xi.validation.mi.validate(v, 0, opts)
322		out.n = n
323		return out, st
324	case validationTypeGroup:
325		if wtyp != protowire.StartGroupType {
326			return out, ValidationUnknown
327		}
328		out, st := xi.validation.mi.validate(b, num, opts)
329		return out, st
330	default:
331		return out, ValidationUnknown
332	}
333}