lazy.go

// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package impl

import (
	"fmt"
	"math/bits"
	"os"
	"reflect"
	"sort"
	"sync/atomic"

	"google.golang.org/protobuf/encoding/protowire"
	"google.golang.org/protobuf/internal/errors"
	"google.golang.org/protobuf/internal/protolazy"
	"google.golang.org/protobuf/reflect/protoreflect"
	preg "google.golang.org/protobuf/reflect/protoregistry"
	piface "google.golang.org/protobuf/runtime/protoiface"
)

var enableLazy int32 = func() int32 {
	if os.Getenv("GOPROTODEBUG") == "nolazy" {
		return 0
	}
	return 1
}()

// EnableLazyUnmarshal enables or disables lazy unmarshaling.
func EnableLazyUnmarshal(enable bool) {
	if enable {
		atomic.StoreInt32(&enableLazy, 1)
		return
	}
	atomic.StoreInt32(&enableLazy, 0)
}

// LazyEnabled reports whether lazy unmarshaling is currently enabled.
func LazyEnabled() bool {
	return atomic.LoadInt32(&enableLazy) != 0
}
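
// Usage sketch: lazy unmarshaling is enabled by default. It can be disabled at
// process start with GOPROTODEBUG=nolazy (see enableLazy above) or toggled at
// run time, for example:
//
//	EnableLazyUnmarshal(false)
//	if LazyEnabled() {
//		// Not reached once lazy unmarshaling has been disabled.
//	}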

// UnmarshalField unmarshals the field with number num in message m.
func UnmarshalField(m interface{}, num protowire.Number) {
	switch m := m.(type) {
	case *messageState:
		m.messageInfo().lazyUnmarshal(m.pointer(), num)
	case *messageReflectWrapper:
		m.messageInfo().lazyUnmarshal(m.pointer(), num)
	default:
		panic(fmt.Sprintf("unsupported wrapper type %T", m))
	}
}
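
// Illustrative example (hypothetical values): to force deferred field number 5
// of a lazily decoded message to be expanded, a caller inside this runtime
// could write
//
//	UnmarshalField(msgState, protowire.Number(5))
//
// where msgState is the message's *messageState (or *messageReflectWrapper).
// Any other wrapper type panics, as above.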

func (mi *MessageInfo) lazyUnmarshal(p pointer, num protoreflect.FieldNumber) {
	var f *coderFieldInfo
	if int(num) < len(mi.denseCoderFields) {
		f = mi.denseCoderFields[num]
	} else {
		f = mi.coderFields[num]
	}
	if f == nil {
		panic(fmt.Sprintf("lazyUnmarshal: no field info for %v.%v", mi.Desc.FullName(), num))
	}
	lazy := *p.Apply(mi.lazyOffset).LazyInfoPtr()
	start, end, found, _, multipleEntries := lazy.FindFieldInProto(uint32(num))
	if !found && multipleEntries == nil {
		panic(fmt.Sprintf("lazyUnmarshal: can't find field data for %v.%v", mi.Desc.FullName(), num))
	}
	// The actual pointer in the message cannot be set until the whole struct is
	// filled in; otherwise we would have data races. Create a separate pointer,
	// fill it in, and then atomically publish it into the message only if the
	// pointer in the original message is still nil (i.e., we won the race).
	fp := pointerOfValue(reflect.New(f.ft))
	if multipleEntries != nil {
		for _, entry := range multipleEntries {
			mi.unmarshalField(lazy.Buffer()[entry.Start:entry.End], fp, f, lazy, lazy.UnmarshalFlags())
		}
	} else {
		mi.unmarshalField(lazy.Buffer()[start:end], fp, f, lazy, lazy.UnmarshalFlags())
	}
	p.Apply(f.offset).AtomicSetPointerIfNil(fp.Elem())
}

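// unmarshalField decodes only the occurrences of field f in the wire data b,
// writing them through p (which must already point at the field's storage) and
// skipping every other field number. Callers pass slices of the lazily
// retained buffer recorded for the message.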
func (mi *MessageInfo) unmarshalField(b []byte, p pointer, f *coderFieldInfo, lazyInfo *protolazy.XXX_lazyUnmarshalInfo, flags piface.UnmarshalInputFlags) error {
	opts := lazyUnmarshalOptions
	opts.flags |= flags
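	// A tag encodes the field number and wire type as (number<<3 | type). For
	// example, field 1 with the bytes wire type (2) encodes as the single byte
	// 0x0a, which the one-byte fast path below handles without calling
	// protowire.ConsumeVarint.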
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		var tag uint64
		if b[0] < 0x80 {
			tag = uint64(b[0])
			b = b[1:]
		} else if len(b) >= 2 && b[1] < 128 {
			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
			b = b[2:]
		} else {
			var n int
			tag, n = protowire.ConsumeVarint(b)
			if n < 0 {
				return errors.New("invalid wire data")
			}
			b = b[n:]
		}
		var num protowire.Number
		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
			return errors.New("invalid wire data")
		} else {
			num = protowire.Number(n)
		}
		wtyp := protowire.Type(tag & 7)
		if num == f.num {
			o, err := f.funcs.unmarshal(b, p, wtyp, f, opts)
			if err == nil {
				b = b[o.n:]
				continue
			}
			if err != errUnknown {
				return err
			}
		}
		n := protowire.ConsumeFieldValue(num, wtyp, b)
		if n < 0 {
			return errors.New("invalid wire data")
		}
		b = b[n:]
	}
	return nil
}

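// skipField validates a lazy message or group field without unmarshaling it,
// running the sub-message's validator over the raw bytes. On success, out.n
// reports how many bytes the field value occupies; the ValidationStatus tells
// the caller whether the deferred data is valid, invalid, or undetermined.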
func (mi *MessageInfo) skipField(b []byte, f *coderFieldInfo, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
	fmi := f.validation.mi
	if fmi == nil {
		fd := mi.Desc.Fields().ByNumber(f.num)
		if fd == nil {
			return out, ValidationUnknown
		}
		messageName := fd.Message().FullName()
		messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
		if err != nil {
			return out, ValidationUnknown
		}
		var ok bool
		fmi, ok = messageType.(*MessageInfo)
		if !ok {
			return out, ValidationUnknown
		}
	}
	fmi.init()
	switch f.validation.typ {
	case validationTypeMessage:
		if wtyp != protowire.BytesType {
			return out, ValidationWrongWireType
		}
		v, n := protowire.ConsumeBytes(b)
		if n < 0 {
			return out, ValidationInvalid
		}
		out, st := fmi.validate(v, 0, opts)
		out.n = n
		return out, st
	case validationTypeGroup:
		if wtyp != protowire.StartGroupType {
			return out, ValidationWrongWireType
		}
		out, st := fmi.validate(b, f.num, opts)
		return out, st
	default:
		return out, ValidationUnknown
	}
}

// unmarshalPointerLazy is similar to unmarshalPointerEager, but it
// specifically handles lazy unmarshaling. It expects lazyOffset and
// presenceOffset to both be valid.
func (mi *MessageInfo) unmarshalPointerLazy(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
	initialized := true
	var requiredMask uint64
	var lazy **protolazy.XXX_lazyUnmarshalInfo
	var presence presence
	var lazyIndex []protolazy.IndexEntry
	var lastNum protowire.Number
	outOfOrder := false
	lazyDecode := false
	presence = p.Apply(mi.presenceOffset).PresenceInfo()
	lazy = p.Apply(mi.lazyOffset).LazyInfoPtr()
	if !presence.AnyPresent(mi.presenceSize) {
		if opts.CanBeLazy() {
			// If the message already contains data, we would need to merge into
			// it, and lazy unmarshaling does not merge. So only enable lazy
			// decoding when the message is empty (no presence bits are set).
			lazyDecode = true
			if *lazy == nil {
				*lazy = &protolazy.XXX_lazyUnmarshalInfo{}
			}
			(*lazy).SetUnmarshalFlags(opts.flags)
			if !opts.AliasBuffer() {
				// Make a copy of the buffer for lazy unmarshaling.
				// Set the AliasBuffer flag so recursive unmarshal
				// operations reuse the copy.
				b = append([]byte{}, b...)
				opts.flags |= piface.UnmarshalAliasBuffer
			}
			(*lazy).SetBuffer(b)
		}
	}
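	// If lazyDecode is still false at this point (the message already had data,
	// or the options do not permit lazy decoding), the loop below unmarshals
	// every field eagerly.
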
	// Track special handling of lazy fields.
	//
	// In the common case, all lazy fields are lazyValidateOnly (and lazyFields
	// remains nil). When a field cannot be validated up front, this map records
	// how that field is to be handled.
	type lazyAction uint8
	const (
		lazyValidateOnly   lazyAction = iota // validate the field only
		lazyUnmarshalNow                     // eagerly unmarshal the field
		lazyUnmarshalLater                   // unmarshal the field after the message is fully processed
	)
	var lazyFields map[*coderFieldInfo]lazyAction
	var exts *map[int32]ExtensionField
	start := len(b)
	pos := 0
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		var tag uint64
		if b[0] < 0x80 {
			tag = uint64(b[0])
			b = b[1:]
		} else if len(b) >= 2 && b[1] < 128 {
			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
			b = b[2:]
		} else {
			var n int
			tag, n = protowire.ConsumeVarint(b)
			if n < 0 {
				return out, errDecode
			}
			b = b[n:]
		}
		var num protowire.Number
		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
			return out, errors.New("invalid field number")
		} else {
			num = protowire.Number(n)
		}
		wtyp := protowire.Type(tag & 7)

		if wtyp == protowire.EndGroupType {
			if num != groupTag {
				return out, errors.New("mismatching end group marker")
			}
			groupTag = 0
			break
		}

		var f *coderFieldInfo
		if int(num) < len(mi.denseCoderFields) {
			f = mi.denseCoderFields[num]
		} else {
			f = mi.coderFields[num]
		}
		var n int
		err := errUnknown
		discardUnknown := false
	Field:
		switch {
		case f != nil:
			if f.funcs.unmarshal == nil {
				break
			}
			if f.isLazy && lazyDecode {
				switch {
				case lazyFields == nil || lazyFields[f] == lazyValidateOnly:
					// Attempt to validate this field and leave it for later lazy unmarshaling.
					o, valid := mi.skipField(b, f, wtyp, opts)
					switch valid {
					case ValidationValid:
						// Skip over the valid field and continue.
						err = nil
						presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
						requiredMask |= f.validation.requiredBit
						if !o.initialized {
							initialized = false
						}
						n = o.n
						break Field
					case ValidationInvalid:
						return out, errors.New("invalid proto wire format")
					case ValidationWrongWireType:
						break Field
					case ValidationUnknown:
						if lazyFields == nil {
							lazyFields = make(map[*coderFieldInfo]lazyAction)
						}
						if presence.Present(f.presenceIndex) {
							// We were unable to determine if the field is valid or not,
							// and we've already skipped over at least one instance of this
							// field. Clear the presence bit (so if we stop decoding early,
							// we don't leave a partially-initialized field around) and flag
							// the field for unmarshaling before we return.
							presence.ClearPresent(f.presenceIndex)
							lazyFields[f] = lazyUnmarshalLater
							discardUnknown = true
							break Field
						} else {
							// We were unable to determine if the field is valid or not,
							// but this is the first time we've seen it. Flag it as needing
							// eager unmarshaling and fall through to the eager unmarshal case below.
							lazyFields[f] = lazyUnmarshalNow
						}
					}
				case lazyFields[f] == lazyUnmarshalLater:
					// This field will be unmarshaled in a separate pass below.
					// Skip over it here.
					discardUnknown = true
					break Field
				default:
					// Eagerly unmarshal the field.
				}
			}
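			// When not decoding lazily (for example, when merging into a message
			// that already has data), make sure a previously deferred lazy field
			// has been expanded before new bytes are unmarshaled into it, so the
			// deferred contents are not lost.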
			if f.isLazy && !lazyDecode && presence.Present(f.presenceIndex) {
				if p.Apply(f.offset).AtomicGetPointer().IsNil() {
					mi.lazyUnmarshal(p, f.num)
				}
			}
			var o unmarshalOutput
			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
			n = o.n
			if err != nil {
				break
			}
			requiredMask |= f.validation.requiredBit
			if f.funcs.isInit != nil && !o.initialized {
				initialized = false
			}
			if f.presenceIndex != noPresence {
				presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
			}
		default:
			// Possible extension.
			if exts == nil && mi.extensionOffset.IsValid() {
				exts = p.Apply(mi.extensionOffset).Extensions()
				if *exts == nil {
					*exts = make(map[int32]ExtensionField)
				}
			}
			if exts == nil {
				break
			}
			var o unmarshalOutput
			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
			if err != nil {
				break
			}
			n = o.n
			if !o.initialized {
				initialized = false
			}
		}
		if err != nil {
			if err != errUnknown {
				return out, err
			}
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, errDecode
			}
			if !discardUnknown && !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
				u := mi.mutableUnknownBytes(p)
				*u = protowire.AppendTag(*u, num, wtyp)
				*u = append(*u, b[:n]...)
			}
		}
		b = b[n:]
		end := start - len(b)
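		// While lazily decoding, record the byte range of every lazy field so
		// lazyUnmarshal can find and decode it later. Contiguous repetitions of
		// the same field number are coalesced into a single index entry.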
		if lazyDecode && f != nil && f.isLazy {
			if num != lastNum {
				lazyIndex = append(lazyIndex, protolazy.IndexEntry{
					FieldNum: uint32(num),
					Start:    uint32(pos),
					End:      uint32(end),
				})
			} else {
				i := len(lazyIndex) - 1
				lazyIndex[i].End = uint32(end)
				lazyIndex[i].MultipleContiguous = true
			}
		}
		if num < lastNum {
			outOfOrder = true
		}
		pos = end
		lastNum = num
	}
	if groupTag != 0 {
		return out, errors.New("missing end group marker")
	}
	if lazyFields != nil {
		// Some fields could not be validated and were deferred; unmarshal them
		// now from the retained buffer.
		for f, action := range lazyFields {
			if action != lazyUnmarshalLater {
				continue
			}
			initialized = false
			if *lazy == nil {
				*lazy = &protolazy.XXX_lazyUnmarshalInfo{}
			}
			if err := mi.unmarshalField((*lazy).Buffer(), p.Apply(f.offset), f, *lazy, opts.flags); err != nil {
				return out, err
			}
			presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
		}
	}
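	// Finish setting up lazy decoding: the index of field byte ranges is kept
	// ordered by field number (then by start offset) so that later lookups can
	// locate a field's bytes, so sort it if fields arrived out of order.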
	if lazyDecode {
		if outOfOrder {
			sort.Slice(lazyIndex, func(i, j int) bool {
				return lazyIndex[i].FieldNum < lazyIndex[j].FieldNum ||
					(lazyIndex[i].FieldNum == lazyIndex[j].FieldNum &&
						lazyIndex[i].Start < lazyIndex[j].Start)
			})
		}
		if *lazy == nil {
			*lazy = &protolazy.XXX_lazyUnmarshalInfo{}
		}

		(*lazy).SetIndex(lazyIndex)
	}
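	// Each required field that was decoded (or validated) contributed its bit
	// to requiredMask; if the popcount does not match the number of required
	// fields, at least one required field is missing, so the result is not
	// reported as initialized.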
	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
		initialized = false
	}
	if initialized {
		out.initialized = true
	}
	out.n = start - len(b)
	return out, nil
}