validate.go

  1// Copyright 2019 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package impl
  6
  7import (
  8	"fmt"
  9	"math"
 10	"math/bits"
 11	"reflect"
 12	"unicode/utf8"
 13
 14	"google.golang.org/protobuf/encoding/protowire"
 15	"google.golang.org/protobuf/internal/encoding/messageset"
 16	"google.golang.org/protobuf/internal/flags"
 17	"google.golang.org/protobuf/internal/genid"
 18	"google.golang.org/protobuf/internal/strs"
 19	"google.golang.org/protobuf/reflect/protoreflect"
 20	"google.golang.org/protobuf/reflect/protoregistry"
 21	"google.golang.org/protobuf/runtime/protoiface"
 22)
 23
 24// ValidationStatus is the result of validating the wire-format encoding of a message.
 25type ValidationStatus int
 26
 27const (
 28	// ValidationUnknown indicates that unmarshaling the message might succeed or fail.
 29	// The validator was unable to render a judgement.
 30	//
 31	// The only causes of this status are an aberrant message type appearing somewhere
 32	// in the message or a failure in the extension resolver.
 33	ValidationUnknown ValidationStatus = iota + 1
 34
 35	// ValidationInvalid indicates that unmarshaling the message will fail.
 36	ValidationInvalid
 37
 38	// ValidationValid indicates that unmarshaling the message will succeed.
 39	ValidationValid
 40
 41	// ValidationWrongWireType indicates that a validated field does not have
 42	// the expected wire type.
 43	ValidationWrongWireType
 44)
 45
 46func (v ValidationStatus) String() string {
 47	switch v {
 48	case ValidationUnknown:
 49		return "ValidationUnknown"
 50	case ValidationInvalid:
 51		return "ValidationInvalid"
 52	case ValidationValid:
 53		return "ValidationValid"
 54	default:
 55		return fmt.Sprintf("ValidationStatus(%d)", int(v))
 56	}
 57}
 58
 59// Validate determines whether the contents of the buffer are a valid wire encoding
 60// of the message type.
 61//
 62// This function is exposed for testing.
 63func Validate(mt protoreflect.MessageType, in protoiface.UnmarshalInput) (out protoiface.UnmarshalOutput, _ ValidationStatus) {
 64	mi, ok := mt.(*MessageInfo)
 65	if !ok {
 66		return out, ValidationUnknown
 67	}
 68	if in.Resolver == nil {
 69		in.Resolver = protoregistry.GlobalTypes
 70	}
 71	o, st := mi.validate(in.Buf, 0, unmarshalOptions{
 72		flags:    in.Flags,
 73		resolver: in.Resolver,
 74	})
 75	if o.initialized {
 76		out.Flags |= protoiface.UnmarshalInitialized
 77	}
 78	return out, st
 79}
 80
 81type validationInfo struct {
 82	mi               *MessageInfo
 83	typ              validationType
 84	keyType, valType validationType
 85
 86	// For non-required fields, requiredBit is 0.
 87	//
 88	// For required fields, requiredBit's nth bit is set, where n is a
 89	// unique index in the range [0, MessageInfo.numRequiredFields).
 90	//
 91	// If there are more than 64 required fields, requiredBit is 0.
 92	requiredBit uint64
 93}
 94
 95type validationType uint8
 96
 97const (
 98	validationTypeOther validationType = iota
 99	validationTypeMessage
100	validationTypeGroup
101	validationTypeMap
102	validationTypeRepeatedVarint
103	validationTypeRepeatedFixed32
104	validationTypeRepeatedFixed64
105	validationTypeVarint
106	validationTypeFixed32
107	validationTypeFixed64
108	validationTypeBytes
109	validationTypeUTF8String
110	validationTypeMessageSetItem
111)
112
113func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
114	var vi validationInfo
115	switch {
116	case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
117		switch fd.Kind() {
118		case protoreflect.MessageKind:
119			vi.typ = validationTypeMessage
120			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
121				vi.mi = getMessageInfo(ot.Field(0).Type)
122			}
123		case protoreflect.GroupKind:
124			vi.typ = validationTypeGroup
125			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
126				vi.mi = getMessageInfo(ot.Field(0).Type)
127			}
128		case protoreflect.StringKind:
129			if strs.EnforceUTF8(fd) {
130				vi.typ = validationTypeUTF8String
131			}
132		}
133	default:
134		vi = newValidationInfo(fd, ft)
135	}
136	if fd.Cardinality() == protoreflect.Required {
137		// Avoid overflow. The required field check is done with a 64-bit mask, with
138		// any message containing more than 64 required fields always reported as
139		// potentially uninitialized, so it is not important to get a precise count
140		// of the required fields past 64.
141		if mi.numRequiredFields < math.MaxUint8 {
142			mi.numRequiredFields++
143			vi.requiredBit = 1 << (mi.numRequiredFields - 1)
144		}
145	}
146	return vi
147}
148
149func newValidationInfo(fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
150	var vi validationInfo
151	switch {
152	case fd.IsList():
153		switch fd.Kind() {
154		case protoreflect.MessageKind:
155			vi.typ = validationTypeMessage
156
157			if ft.Kind() == reflect.Ptr {
158				// Repeated opaque message fields are *[]*T.
159				ft = ft.Elem()
160			}
161
162			if ft.Kind() == reflect.Slice {
163				vi.mi = getMessageInfo(ft.Elem())
164			}
165		case protoreflect.GroupKind:
166			vi.typ = validationTypeGroup
167
168			if ft.Kind() == reflect.Ptr {
169				// Repeated opaque message fields are *[]*T.
170				ft = ft.Elem()
171			}
172
173			if ft.Kind() == reflect.Slice {
174				vi.mi = getMessageInfo(ft.Elem())
175			}
176		case protoreflect.StringKind:
177			vi.typ = validationTypeBytes
178			if strs.EnforceUTF8(fd) {
179				vi.typ = validationTypeUTF8String
180			}
181		default:
182			switch wireTypes[fd.Kind()] {
183			case protowire.VarintType:
184				vi.typ = validationTypeRepeatedVarint
185			case protowire.Fixed32Type:
186				vi.typ = validationTypeRepeatedFixed32
187			case protowire.Fixed64Type:
188				vi.typ = validationTypeRepeatedFixed64
189			}
190		}
191	case fd.IsMap():
192		vi.typ = validationTypeMap
193		switch fd.MapKey().Kind() {
194		case protoreflect.StringKind:
195			if strs.EnforceUTF8(fd) {
196				vi.keyType = validationTypeUTF8String
197			}
198		}
199		switch fd.MapValue().Kind() {
200		case protoreflect.MessageKind:
201			vi.valType = validationTypeMessage
202			if ft.Kind() == reflect.Map {
203				vi.mi = getMessageInfo(ft.Elem())
204			}
205		case protoreflect.StringKind:
206			if strs.EnforceUTF8(fd) {
207				vi.valType = validationTypeUTF8String
208			}
209		}
210	default:
211		switch fd.Kind() {
212		case protoreflect.MessageKind:
213			vi.typ = validationTypeMessage
214			vi.mi = getMessageInfo(ft)
215		case protoreflect.GroupKind:
216			vi.typ = validationTypeGroup
217			vi.mi = getMessageInfo(ft)
218		case protoreflect.StringKind:
219			vi.typ = validationTypeBytes
220			if strs.EnforceUTF8(fd) {
221				vi.typ = validationTypeUTF8String
222			}
223		default:
224			switch wireTypes[fd.Kind()] {
225			case protowire.VarintType:
226				vi.typ = validationTypeVarint
227			case protowire.Fixed32Type:
228				vi.typ = validationTypeFixed32
229			case protowire.Fixed64Type:
230				vi.typ = validationTypeFixed64
231			case protowire.BytesType:
232				vi.typ = validationTypeBytes
233			}
234		}
235	}
236	return vi
237}
238
239func (mi *MessageInfo) validate(b []byte, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
240	mi.init()
241	type validationState struct {
242		typ              validationType
243		keyType, valType validationType
244		endGroup         protowire.Number
245		mi               *MessageInfo
246		tail             []byte
247		requiredMask     uint64
248	}
249
250	// Pre-allocate some slots to avoid repeated slice reallocation.
251	states := make([]validationState, 0, 16)
252	states = append(states, validationState{
253		typ: validationTypeMessage,
254		mi:  mi,
255	})
256	if groupTag > 0 {
257		states[0].typ = validationTypeGroup
258		states[0].endGroup = groupTag
259	}
260	initialized := true
261	start := len(b)
262State:
263	for len(states) > 0 {
264		st := &states[len(states)-1]
265		for len(b) > 0 {
266			// Parse the tag (field number and wire type).
267			var tag uint64
268			if b[0] < 0x80 {
269				tag = uint64(b[0])
270				b = b[1:]
271			} else if len(b) >= 2 && b[1] < 128 {
272				tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
273				b = b[2:]
274			} else {
275				var n int
276				tag, n = protowire.ConsumeVarint(b)
277				if n < 0 {
278					return out, ValidationInvalid
279				}
280				b = b[n:]
281			}
282			var num protowire.Number
283			if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
284				return out, ValidationInvalid
285			} else {
286				num = protowire.Number(n)
287			}
288			wtyp := protowire.Type(tag & 7)
289
290			if wtyp == protowire.EndGroupType {
291				if st.endGroup == num {
292					goto PopState
293				}
294				return out, ValidationInvalid
295			}
296			var vi validationInfo
297			switch {
298			case st.typ == validationTypeMap:
299				switch num {
300				case genid.MapEntry_Key_field_number:
301					vi.typ = st.keyType
302				case genid.MapEntry_Value_field_number:
303					vi.typ = st.valType
304					vi.mi = st.mi
305					vi.requiredBit = 1
306				}
307			case flags.ProtoLegacy && st.mi.isMessageSet:
308				switch num {
309				case messageset.FieldItem:
310					vi.typ = validationTypeMessageSetItem
311				}
312			default:
313				var f *coderFieldInfo
314				if int(num) < len(st.mi.denseCoderFields) {
315					f = st.mi.denseCoderFields[num]
316				} else {
317					f = st.mi.coderFields[num]
318				}
319				if f != nil {
320					vi = f.validation
321					break
322				}
323				// Possible extension field.
324				//
325				// TODO: We should return ValidationUnknown when:
326				//   1. The resolver is not frozen. (More extensions may be added to it.)
327				//   2. The resolver returns preg.NotFound.
328				// In this case, a type added to the resolver in the future could cause
329				// unmarshaling to begin failing. Supporting this requires some way to
330				// determine if the resolver is frozen.
331				xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
332				if err != nil && err != protoregistry.NotFound {
333					return out, ValidationUnknown
334				}
335				if err == nil {
336					vi = getExtensionFieldInfo(xt).validation
337				}
338			}
339			if vi.requiredBit != 0 {
340				// Check that the field has a compatible wire type.
341				// We only need to consider non-repeated field types,
342				// since repeated fields (and maps) can never be required.
343				ok := false
344				switch vi.typ {
345				case validationTypeVarint:
346					ok = wtyp == protowire.VarintType
347				case validationTypeFixed32:
348					ok = wtyp == protowire.Fixed32Type
349				case validationTypeFixed64:
350					ok = wtyp == protowire.Fixed64Type
351				case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
352					ok = wtyp == protowire.BytesType
353				case validationTypeGroup:
354					ok = wtyp == protowire.StartGroupType
355				}
356				if ok {
357					st.requiredMask |= vi.requiredBit
358				}
359			}
360
361			switch wtyp {
362			case protowire.VarintType:
363				if len(b) >= 10 {
364					switch {
365					case b[0] < 0x80:
366						b = b[1:]
367					case b[1] < 0x80:
368						b = b[2:]
369					case b[2] < 0x80:
370						b = b[3:]
371					case b[3] < 0x80:
372						b = b[4:]
373					case b[4] < 0x80:
374						b = b[5:]
375					case b[5] < 0x80:
376						b = b[6:]
377					case b[6] < 0x80:
378						b = b[7:]
379					case b[7] < 0x80:
380						b = b[8:]
381					case b[8] < 0x80:
382						b = b[9:]
383					case b[9] < 0x80 && b[9] < 2:
384						b = b[10:]
385					default:
386						return out, ValidationInvalid
387					}
388				} else {
389					switch {
390					case len(b) > 0 && b[0] < 0x80:
391						b = b[1:]
392					case len(b) > 1 && b[1] < 0x80:
393						b = b[2:]
394					case len(b) > 2 && b[2] < 0x80:
395						b = b[3:]
396					case len(b) > 3 && b[3] < 0x80:
397						b = b[4:]
398					case len(b) > 4 && b[4] < 0x80:
399						b = b[5:]
400					case len(b) > 5 && b[5] < 0x80:
401						b = b[6:]
402					case len(b) > 6 && b[6] < 0x80:
403						b = b[7:]
404					case len(b) > 7 && b[7] < 0x80:
405						b = b[8:]
406					case len(b) > 8 && b[8] < 0x80:
407						b = b[9:]
408					case len(b) > 9 && b[9] < 2:
409						b = b[10:]
410					default:
411						return out, ValidationInvalid
412					}
413				}
414				continue State
415			case protowire.BytesType:
416				var size uint64
417				if len(b) >= 1 && b[0] < 0x80 {
418					size = uint64(b[0])
419					b = b[1:]
420				} else if len(b) >= 2 && b[1] < 128 {
421					size = uint64(b[0]&0x7f) + uint64(b[1])<<7
422					b = b[2:]
423				} else {
424					var n int
425					size, n = protowire.ConsumeVarint(b)
426					if n < 0 {
427						return out, ValidationInvalid
428					}
429					b = b[n:]
430				}
431				if size > uint64(len(b)) {
432					return out, ValidationInvalid
433				}
434				v := b[:size]
435				b = b[size:]
436				switch vi.typ {
437				case validationTypeMessage:
438					if vi.mi == nil {
439						return out, ValidationUnknown
440					}
441					vi.mi.init()
442					fallthrough
443				case validationTypeMap:
444					if vi.mi != nil {
445						vi.mi.init()
446					}
447					states = append(states, validationState{
448						typ:     vi.typ,
449						keyType: vi.keyType,
450						valType: vi.valType,
451						mi:      vi.mi,
452						tail:    b,
453					})
454					b = v
455					continue State
456				case validationTypeRepeatedVarint:
457					// Packed field.
458					for len(v) > 0 {
459						_, n := protowire.ConsumeVarint(v)
460						if n < 0 {
461							return out, ValidationInvalid
462						}
463						v = v[n:]
464					}
465				case validationTypeRepeatedFixed32:
466					// Packed field.
467					if len(v)%4 != 0 {
468						return out, ValidationInvalid
469					}
470				case validationTypeRepeatedFixed64:
471					// Packed field.
472					if len(v)%8 != 0 {
473						return out, ValidationInvalid
474					}
475				case validationTypeUTF8String:
476					if !utf8.Valid(v) {
477						return out, ValidationInvalid
478					}
479				}
480			case protowire.Fixed32Type:
481				if len(b) < 4 {
482					return out, ValidationInvalid
483				}
484				b = b[4:]
485			case protowire.Fixed64Type:
486				if len(b) < 8 {
487					return out, ValidationInvalid
488				}
489				b = b[8:]
490			case protowire.StartGroupType:
491				switch {
492				case vi.typ == validationTypeGroup:
493					if vi.mi == nil {
494						return out, ValidationUnknown
495					}
496					vi.mi.init()
497					states = append(states, validationState{
498						typ:      validationTypeGroup,
499						mi:       vi.mi,
500						endGroup: num,
501					})
502					continue State
503				case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
504					typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
505					if err != nil {
506						return out, ValidationInvalid
507					}
508					xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
509					switch {
510					case err == protoregistry.NotFound:
511						b = b[n:]
512					case err != nil:
513						return out, ValidationUnknown
514					default:
515						xvi := getExtensionFieldInfo(xt).validation
516						if xvi.mi != nil {
517							xvi.mi.init()
518						}
519						states = append(states, validationState{
520							typ:  xvi.typ,
521							mi:   xvi.mi,
522							tail: b[n:],
523						})
524						b = v
525						continue State
526					}
527				default:
528					n := protowire.ConsumeFieldValue(num, wtyp, b)
529					if n < 0 {
530						return out, ValidationInvalid
531					}
532					b = b[n:]
533				}
534			default:
535				return out, ValidationInvalid
536			}
537		}
538		if st.endGroup != 0 {
539			return out, ValidationInvalid
540		}
541		if len(b) != 0 {
542			return out, ValidationInvalid
543		}
544		b = st.tail
545	PopState:
546		numRequiredFields := 0
547		switch st.typ {
548		case validationTypeMessage, validationTypeGroup:
549			numRequiredFields = int(st.mi.numRequiredFields)
550		case validationTypeMap:
551			// If this is a map field with a message value that contains
552			// required fields, require that the value be present.
553			if st.mi != nil && st.mi.numRequiredFields > 0 {
554				numRequiredFields = 1
555			}
556		}
557		// If there are more than 64 required fields, this check will
558		// always fail and we will report that the message is potentially
559		// uninitialized.
560		if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
561			initialized = false
562		}
563		states = states[:len(states)-1]
564	}
565	out.n = start - len(b)
566	if initialized {
567		out.initialized = true
568	}
569	return out, ValidationValid
570}