lazy.go

// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package protolazy contains internal data structures for lazy message decoding.
package protolazy

import (
	"fmt"
	"sort"

	"google.golang.org/protobuf/encoding/protowire"
	piface "google.golang.org/protobuf/runtime/protoiface"
)

// IndexEntry is one entry in the index of the fields of an encoded message;
// the index does not descend into sub-messages.
type IndexEntry struct {
	FieldNum uint32
	// first byte of this tag/field
	Start uint32
	// first byte after a contiguous sequence of bytes for this tag/field, which could
	// include a single encoding of the field, or multiple encodings for the field
	End uint32
	// True if this protobuf segment includes multiple encodings of the field
	MultipleContiguous bool
}
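
// For example (illustrative values only): if field 7 were encoded twice back
// to back, starting at byte 10 and ending at byte 26 of the buffer, its entry
// would look like IndexEntry{FieldNum: 7, Start: 10, End: 26, MultipleContiguous: true}.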

// XXX_lazyUnmarshalInfo has information about a particular lazily decoded message.
//
// Deprecated: Do not use. This will be deleted in the near future.
type XXX_lazyUnmarshalInfo struct {
	// Index of fields and their positions in the protobuf for this
	// message.  Make index be a pointer to a slice so it can be updated
	// atomically.  The index pointer is only set once (lazily when/if
	// the index is first needed), and must always be SET and LOADED
	// ATOMICALLY.
	index *[]IndexEntry
	// The protobuf associated with this lazily decoded message.  It is
	// only set during proto.Unmarshal().  It doesn't need to be set and
	// loaded atomically, since any simultaneous set (Unmarshal) and read
	// (during a get) would already be a race in the app code.
	Protobuf []byte
	// The flags present when Unmarshal was originally called for this particular message.
	unmarshalFlags piface.UnmarshalInputFlags
}
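
// The index pointer is only ever read and written through the atomicLoadIndex
// and atomicStoreIndex helpers used below; once a caller loads a non-nil index
// pointer, the slice it points to is never mutated, so it can be read without
// further synchronization.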

// The Buffer and SetBuffer methods let v2/internal/impl interact with
// XXX_lazyUnmarshalInfo via an interface, to avoid an import cycle.

// Buffer returns the lazy unmarshal buffer.
//
// Deprecated: Do not use. This will be deleted in the near future.
func (lazy *XXX_lazyUnmarshalInfo) Buffer() []byte {
	return lazy.Protobuf
}

// SetBuffer sets the lazy unmarshal buffer.
//
// Deprecated: Do not use. This will be deleted in the near future.
func (lazy *XXX_lazyUnmarshalInfo) SetBuffer(b []byte) {
	lazy.Protobuf = b
}

// SetUnmarshalFlags is called to set a copy of the original unmarshalInputFlags.
// The flags should reflect how Unmarshal was called.
func (lazy *XXX_lazyUnmarshalInfo) SetUnmarshalFlags(f piface.UnmarshalInputFlags) {
	lazy.unmarshalFlags = f
}

// UnmarshalFlags returns the original unmarshalInputFlags.
func (lazy *XXX_lazyUnmarshalInfo) UnmarshalFlags() piface.UnmarshalInputFlags {
	return lazy.unmarshalFlags
}

// AllowedPartial returns true if the user originally unmarshaled this message
// with AllowPartial set to true.
func (lazy *XXX_lazyUnmarshalInfo) AllowedPartial() bool {
	return (lazy.unmarshalFlags & piface.UnmarshalCheckRequired) == 0
}

func protoFieldNumber(tag uint32) uint32 {
	return tag >> 3
}
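
// A tag varint packs the field number into its upper bits and the wire type
// into its low three bits (tag = fieldNum<<3 | wireType); for example, a tag
// byte of 0x12 denotes field 2 with wire type 2 (length-delimited).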

// buildIndex builds an index of the specified protobuf, returning the index
// slice and an error.
func buildIndex(buf []byte) ([]IndexEntry, error) {
	index := make([]IndexEntry, 0, 16)
	var lastProtoFieldNum uint32
	var outOfOrder bool

	var r BufferReader = NewBufferReader(buf)

	for !r.Done() {
		var tag uint32
		var err error
		var curPos = r.Pos
		// INLINED: tag, err = r.DecodeVarint32()
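		// The decode below is a hand-unrolled DecodeVarint32: a
		// single-byte tag is the common fast path, DecodeVarintSlow is
		// used when fewer than five bytes remain in the buffer, and
		// otherwise up to five bytes are consumed (the maximum length
		// of a 32-bit varint).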
		{
			i := r.Pos
			buf := r.Buf

			if i >= len(buf) {
				return nil, errOutOfBounds
			} else if buf[i] < 0x80 {
				r.Pos++
				tag = uint32(buf[i])
			} else if r.Remaining() < 5 {
				var v uint64
				v, err = r.DecodeVarintSlow()
				tag = uint32(v)
			} else {
				var v uint32
				// we already checked the first byte
				tag = uint32(buf[i]) & 127
				i++

				v = uint32(buf[i])
				i++
				tag |= (v & 127) << 7
				if v < 128 {
					goto done
				}

				v = uint32(buf[i])
				i++
				tag |= (v & 127) << 14
				if v < 128 {
					goto done
				}

				v = uint32(buf[i])
				i++
				tag |= (v & 127) << 21
				if v < 128 {
					goto done
				}

				v = uint32(buf[i])
				i++
				tag |= (v & 127) << 28
				if v < 128 {
					goto done
				}

				return nil, errOutOfBounds

			done:
				r.Pos = i
			}
		}
		// DONE: tag, err = r.DecodeVarint32()

		fieldNum := protoFieldNumber(tag)
		if fieldNum < lastProtoFieldNum {
			outOfOrder = true
		}

		// Skip the current value -- will skip over an entire group as well.
		// INLINED: err = r.SkipValue(tag)
		wireType := tag & 0x7
		switch protowire.Type(wireType) {
		case protowire.VarintType:
			// INLINED: err = r.SkipVarint()
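			// Skip up to ten varint bytes by hand (ten is the
			// maximum length of a varint-encoded uint64), falling
			// back to DecodeVarintSlow near the end of the buffer
			// so bounds are still checked.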
			i := r.Pos

			if len(r.Buf)-i < 10 {
				// Use DecodeVarintSlow() to skip while
				// checking for buffer overflow, but ignore result
				_, err = r.DecodeVarintSlow()
				goto out2
			}
			if r.Buf[i] < 0x80 {
				goto out
			}
			i++

			if r.Buf[i] < 0x80 {
				goto out
			}
			i++

			if r.Buf[i] < 0x80 {
				goto out
			}
			i++

			if r.Buf[i] < 0x80 {
				goto out
			}
			i++

			if r.Buf[i] < 0x80 {
				goto out
			}
			i++

			if r.Buf[i] < 0x80 {
				goto out
			}
			i++

			if r.Buf[i] < 0x80 {
				goto out
			}
			i++

			if r.Buf[i] < 0x80 {
				goto out
			}
			i++

			if r.Buf[i] < 0x80 {
				goto out
			}
			i++

			if r.Buf[i] < 0x80 {
				goto out
			}
			return nil, errOverflow
		out:
			r.Pos = i + 1
			// DONE: err = r.SkipVarint()
		case protowire.Fixed64Type:
			err = r.SkipFixed64()
		case protowire.BytesType:
			var n uint32
			n, err = r.DecodeVarint32()
			if err == nil {
				err = r.Skip(int(n))
			}
		case protowire.StartGroupType:
			err = r.SkipGroup(tag)
		case protowire.Fixed32Type:
			err = r.SkipFixed32()
		default:
			err = fmt.Errorf("unexpected wire type (%d)", wireType)
		}
		// DONE: err = r.SkipValue(tag)

	out2:
		if err != nil {
			return nil, err
		}
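		// A run of contiguous encodings of the same field is collapsed
		// into a single entry whose End is extended and which is marked
		// MultipleContiguous; any other field number starts a new entry.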
		if fieldNum != lastProtoFieldNum {
			index = append(index, IndexEntry{FieldNum: fieldNum,
				Start: uint32(curPos),
				End:   uint32(r.Pos)},
			)
		} else {
			index[len(index)-1].End = uint32(r.Pos)
			index[len(index)-1].MultipleContiguous = true
		}
		lastProtoFieldNum = fieldNum
	}
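	// lookupField scans the index linearly and stops early once it passes
	// the requested field number, which assumes entries are ordered by
	// field number; if fields appeared out of order on the wire, sort the
	// index before returning it.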
	if outOfOrder {
		sort.Slice(index, func(i, j int) bool {
			return index[i].FieldNum < index[j].FieldNum ||
				(index[i].FieldNum == index[j].FieldNum &&
					index[i].Start < index[j].Start)
		})
	}
	return index, nil
}
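
// As a rough sketch of buildIndex's behavior (illustrative bytes, not taken
// from the original source), indexing the wire bytes
//
//	0x08 0x96 0x01        field 1, varint 150
//	0x12 0x02 0x68 0x69   field 2, 2-byte length-delimited value
//	0x12 0x02 0x79 0x6f   field 2 again, contiguous with the previous one
//
// would yield
//
//	{FieldNum: 1, Start: 0, End: 3}
//	{FieldNum: 2, Start: 3, End: 11, MultipleContiguous: true}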

func (lazy *XXX_lazyUnmarshalInfo) SizeField(num uint32) (size int) {
	start, end, found, _, multipleEntries := lazy.FindFieldInProto(num)
	if multipleEntries != nil {
		for _, entry := range multipleEntries {
			size += int(entry.End - entry.Start)
		}
		return size
	}
	if !found {
		return 0
	}
	return int(end - start)
}

func (lazy *XXX_lazyUnmarshalInfo) AppendField(b []byte, num uint32) ([]byte, bool) {
	start, end, found, _, multipleEntries := lazy.FindFieldInProto(num)
	if multipleEntries != nil {
		for _, entry := range multipleEntries {
			b = append(b, lazy.Protobuf[entry.Start:entry.End]...)
		}
		return b, true
	}
	if !found {
		return nil, false
	}
	b = append(b, lazy.Protobuf[start:end]...)
	return b, true
}
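
// A minimal usage sketch (illustrative only, not part of the original
// source): a caller wanting the raw wire-format bytes of field 5 could do
//
//	buf := make([]byte, 0, lazy.SizeField(5))
//	if raw, ok := lazy.AppendField(buf, 5); ok {
//		// raw holds every encoding of field 5 present in lazy.Protobuf.
//	}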

func (lazy *XXX_lazyUnmarshalInfo) SetIndex(index []IndexEntry) {
	atomicStoreIndex(&lazy.index, &index)
}

// FindFieldInProto looks for field fieldNum in the lazyUnmarshalInfo
// information (including the protobuf), building and caching the index if
// needed, and returns the field's start/end offsets, whether it was found,
// whether it has multiple contiguous encodings, and any non-contiguous
// entries (see lookupField).
func (lazy *XXX_lazyUnmarshalInfo) FindFieldInProto(fieldNum uint32) (start, end uint32, found, multipleContiguous bool, multipleEntries []IndexEntry) {
	if lazy.Protobuf == nil {
		// There is no backing protobuf for this message -- it was made from a builder
		return 0, 0, false, false, nil
	}
	index := atomicLoadIndex(&lazy.index)
	if index == nil {
		r, err := buildIndex(lazy.Protobuf)
		if err != nil {
			panic(fmt.Sprintf("FindFieldInProto: error building index when looking for field %d: %v", fieldNum, err))
		}
		// lazy.index is a pointer to the slice returned by buildIndex
		index = &r
		atomicStoreIndex(&lazy.index, index)
	}
	return lookupField(index, fieldNum)
}

// lookupField uses the index to return the offset at which the indicated
// field starts, the offset immediately after the field ends (including all
// contiguous instances of a repeated field), and bools indicating whether
// the field was found and whether there are multiple encodings of the field
// in that byte range.
//
// To handle the uncommon case where there are repeated encodings for the
// same field which are not consecutive in the protobuf (so we need to return
// multiple start/end offsets), we also return a slice multipleEntries.  If
// multipleEntries is non-nil, then multiple entries were found, and the
// values in the slice should be used, rather than start/end/found.
func lookupField(indexp *[]IndexEntry, fieldNum uint32) (start, end uint32, found bool, multipleContiguous bool, multipleEntries []IndexEntry) {
	// The pointer indexp to the index was already loaded atomically.
	// The slice is uniquely associated with the pointer, so it doesn't
	// need to be loaded atomically.
	index := *indexp
	for i, entry := range index {
		if fieldNum == entry.FieldNum {
			if i < len(index)-1 && entry.FieldNum == index[i+1].FieldNum {
				// Handle the uncommon case where there are
				// repeated entries for the same field which
				// are not contiguous in the protobuf.
				multiple := make([]IndexEntry, 1, 2)
				multiple[0] = IndexEntry{fieldNum, entry.Start, entry.End, entry.MultipleContiguous}
				i++
				for i < len(index) && index[i].FieldNum == fieldNum {
					multiple = append(multiple, IndexEntry{fieldNum, index[i].Start, index[i].End, index[i].MultipleContiguous})
					i++
				}
				return 0, 0, false, false, multiple

			}
			return entry.Start, entry.End, true, entry.MultipleContiguous, nil
		}
		if fieldNum < entry.FieldNum {
			return 0, 0, false, false, nil
		}
	}
	return 0, 0, false, false, nil
}