checkinit.go

  1// Copyright 2019 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package impl
  6
  7import (
  8	"sync"
  9
 10	"google.golang.org/protobuf/internal/errors"
 11	"google.golang.org/protobuf/reflect/protoreflect"
 12	"google.golang.org/protobuf/runtime/protoiface"
 13)
 14
 15func (mi *MessageInfo) checkInitialized(in protoiface.CheckInitializedInput) (protoiface.CheckInitializedOutput, error) {
 16	var p pointer
 17	if ms, ok := in.Message.(*messageState); ok {
 18		p = ms.pointer()
 19	} else {
 20		p = in.Message.(*messageReflectWrapper).pointer()
 21	}
 22	return protoiface.CheckInitializedOutput{}, mi.checkInitializedPointer(p)
 23}
 24
 25func (mi *MessageInfo) checkInitializedPointer(p pointer) error {
 26	mi.init()
 27	if !mi.needsInitCheck {
 28		return nil
 29	}
 30	if p.IsNil() {
 31		for _, f := range mi.orderedCoderFields {
 32			if f.isRequired {
 33				return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName()))
 34			}
 35		}
 36		return nil
 37	}
 38
 39	var presence presence
 40	if mi.presenceOffset.IsValid() {
 41		presence = p.Apply(mi.presenceOffset).PresenceInfo()
 42	}
 43
 44	if mi.extensionOffset.IsValid() {
 45		e := p.Apply(mi.extensionOffset).Extensions()
 46		if err := mi.isInitExtensions(e); err != nil {
 47			return err
 48		}
 49	}
 50	for _, f := range mi.orderedCoderFields {
 51		if !f.isRequired && f.funcs.isInit == nil {
 52			continue
 53		}
 54
 55		if f.presenceIndex != noPresence {
 56			if !presence.Present(f.presenceIndex) {
 57				if f.isRequired {
 58					return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName()))
 59				}
 60				continue
 61			}
 62			if f.funcs.isInit != nil {
 63				f.mi.init()
 64				if f.mi.needsInitCheck {
 65					if f.isLazy && p.Apply(f.offset).AtomicGetPointer().IsNil() {
 66						lazy := *p.Apply(mi.lazyOffset).LazyInfoPtr()
 67						if !lazy.AllowedPartial() {
 68							// Nothing to see here, it was checked on unmarshal
 69							continue
 70						}
 71						mi.lazyUnmarshal(p, f.num)
 72					}
 73					if err := f.funcs.isInit(p.Apply(f.offset), f); err != nil {
 74						return err
 75					}
 76				}
 77			}
 78			continue
 79		}
 80
 81		fptr := p.Apply(f.offset)
 82		if f.isPointer && fptr.Elem().IsNil() {
 83			if f.isRequired {
 84				return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName()))
 85			}
 86			continue
 87		}
 88		if f.funcs.isInit == nil {
 89			continue
 90		}
 91		if err := f.funcs.isInit(fptr, f); err != nil {
 92			return err
 93		}
 94	}
 95	return nil
 96}
 97
 98func (mi *MessageInfo) isInitExtensions(ext *map[int32]ExtensionField) error {
 99	if ext == nil {
100		return nil
101	}
102	for _, x := range *ext {
103		ei := getExtensionFieldInfo(x.Type())
104		if ei.funcs.isInit == nil || x.isUnexpandedLazy() {
105			continue
106		}
107		v := x.Value()
108		if !v.IsValid() {
109			continue
110		}
111		if err := ei.funcs.isInit(v); err != nil {
112			return err
113		}
114	}
115	return nil
116}
117
118var (
119	needsInitCheckMu  sync.Mutex
120	needsInitCheckMap sync.Map
121)
122
123// needsInitCheck reports whether a message needs to be checked for partial initialization.
124//
125// It returns true if the message transitively includes any required or extension fields.
126func needsInitCheck(md protoreflect.MessageDescriptor) bool {
127	if v, ok := needsInitCheckMap.Load(md); ok {
128		if has, ok := v.(bool); ok {
129			return has
130		}
131	}
132	needsInitCheckMu.Lock()
133	defer needsInitCheckMu.Unlock()
134	return needsInitCheckLocked(md)
135}
136
137func needsInitCheckLocked(md protoreflect.MessageDescriptor) (has bool) {
138	if v, ok := needsInitCheckMap.Load(md); ok {
139		// If has is true, we've previously determined that this message
140		// needs init checks.
141		//
142		// If has is false, we've previously determined that it can never
143		// be uninitialized.
144		//
145		// If has is not a bool, we've just encountered a cycle in the
146		// message graph. In this case, it is safe to return false: If
147		// the message does have required fields, we'll detect them later
148		// in the graph traversal.
149		has, ok := v.(bool)
150		return ok && has
151	}
152	needsInitCheckMap.Store(md, struct{}{}) // avoid cycles while descending into this message
153	defer func() {
154		needsInitCheckMap.Store(md, has)
155	}()
156	if md.RequiredNumbers().Len() > 0 {
157		return true
158	}
159	if md.ExtensionRanges().Len() > 0 {
160		return true
161	}
162	for i := 0; i < md.Fields().Len(); i++ {
163		fd := md.Fields().Get(i)
164		// Map keys are never messages, so just consider the map value.
165		if fd.IsMap() {
166			fd = fd.MapValue()
167		}
168		fmd := fd.Message()
169		if fmd != nil && needsInitCheckLocked(fmd) {
170			return true
171		}
172	}
173	return false
174}