bufferreader.go

  1// Copyright 2024 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5// Helper code for parsing a protocol buffer
  6
  7package protolazy
  8
  9import (
 10	"errors"
 11	"fmt"
 12	"io"
 13
 14	"google.golang.org/protobuf/encoding/protowire"
 15)
 16
 17// BufferReader is a structure encapsulating a protobuf and a current position
 18type BufferReader struct {
 19	Buf []byte
 20	Pos int
 21}
 22
 23// NewBufferReader creates a new BufferRead from a protobuf
 24func NewBufferReader(buf []byte) BufferReader {
 25	return BufferReader{Buf: buf, Pos: 0}
 26}
 27
 28var errOutOfBounds = errors.New("protobuf decoding: out of bounds")
 29var errOverflow = errors.New("proto: integer overflow")
 30
 31func (b *BufferReader) DecodeVarintSlow() (x uint64, err error) {
 32	i := b.Pos
 33	l := len(b.Buf)
 34
 35	for shift := uint(0); shift < 64; shift += 7 {
 36		if i >= l {
 37			err = io.ErrUnexpectedEOF
 38			return
 39		}
 40		v := b.Buf[i]
 41		i++
 42		x |= (uint64(v) & 0x7F) << shift
 43		if v < 0x80 {
 44			b.Pos = i
 45			return
 46		}
 47	}
 48
 49	// The number is too large to represent in a 64-bit value.
 50	err = errOverflow
 51	return
 52}
 53
 54// decodeVarint decodes a varint at the current position
 55func (b *BufferReader) DecodeVarint() (x uint64, err error) {
 56	i := b.Pos
 57	buf := b.Buf
 58
 59	if i >= len(buf) {
 60		return 0, io.ErrUnexpectedEOF
 61	} else if buf[i] < 0x80 {
 62		b.Pos++
 63		return uint64(buf[i]), nil
 64	} else if len(buf)-i < 10 {
 65		return b.DecodeVarintSlow()
 66	}
 67
 68	var v uint64
 69	// we already checked the first byte
 70	x = uint64(buf[i]) & 127
 71	i++
 72
 73	v = uint64(buf[i])
 74	i++
 75	x |= (v & 127) << 7
 76	if v < 128 {
 77		goto done
 78	}
 79
 80	v = uint64(buf[i])
 81	i++
 82	x |= (v & 127) << 14
 83	if v < 128 {
 84		goto done
 85	}
 86
 87	v = uint64(buf[i])
 88	i++
 89	x |= (v & 127) << 21
 90	if v < 128 {
 91		goto done
 92	}
 93
 94	v = uint64(buf[i])
 95	i++
 96	x |= (v & 127) << 28
 97	if v < 128 {
 98		goto done
 99	}
100
101	v = uint64(buf[i])
102	i++
103	x |= (v & 127) << 35
104	if v < 128 {
105		goto done
106	}
107
108	v = uint64(buf[i])
109	i++
110	x |= (v & 127) << 42
111	if v < 128 {
112		goto done
113	}
114
115	v = uint64(buf[i])
116	i++
117	x |= (v & 127) << 49
118	if v < 128 {
119		goto done
120	}
121
122	v = uint64(buf[i])
123	i++
124	x |= (v & 127) << 56
125	if v < 128 {
126		goto done
127	}
128
129	v = uint64(buf[i])
130	i++
131	x |= (v & 127) << 63
132	if v < 128 {
133		goto done
134	}
135
136	return 0, errOverflow
137
138done:
139	b.Pos = i
140	return
141}
142
143// decodeVarint32 decodes a varint32 at the current position
144func (b *BufferReader) DecodeVarint32() (x uint32, err error) {
145	i := b.Pos
146	buf := b.Buf
147
148	if i >= len(buf) {
149		return 0, io.ErrUnexpectedEOF
150	} else if buf[i] < 0x80 {
151		b.Pos++
152		return uint32(buf[i]), nil
153	} else if len(buf)-i < 5 {
154		v, err := b.DecodeVarintSlow()
155		return uint32(v), err
156	}
157
158	var v uint32
159	// we already checked the first byte
160	x = uint32(buf[i]) & 127
161	i++
162
163	v = uint32(buf[i])
164	i++
165	x |= (v & 127) << 7
166	if v < 128 {
167		goto done
168	}
169
170	v = uint32(buf[i])
171	i++
172	x |= (v & 127) << 14
173	if v < 128 {
174		goto done
175	}
176
177	v = uint32(buf[i])
178	i++
179	x |= (v & 127) << 21
180	if v < 128 {
181		goto done
182	}
183
184	v = uint32(buf[i])
185	i++
186	x |= (v & 127) << 28
187	if v < 128 {
188		goto done
189	}
190
191	return 0, errOverflow
192
193done:
194	b.Pos = i
195	return
196}
197
198// skipValue skips a value in the protobuf, based on the specified tag
199func (b *BufferReader) SkipValue(tag uint32) (err error) {
200	wireType := tag & 0x7
201	switch protowire.Type(wireType) {
202	case protowire.VarintType:
203		err = b.SkipVarint()
204	case protowire.Fixed64Type:
205		err = b.SkipFixed64()
206	case protowire.BytesType:
207		var n uint32
208		n, err = b.DecodeVarint32()
209		if err == nil {
210			err = b.Skip(int(n))
211		}
212	case protowire.StartGroupType:
213		err = b.SkipGroup(tag)
214	case protowire.Fixed32Type:
215		err = b.SkipFixed32()
216	default:
217		err = fmt.Errorf("Unexpected wire type (%d)", wireType)
218	}
219	return
220}
221
222// skipGroup skips a group with the specified tag.  It executes efficiently using a tag stack
223func (b *BufferReader) SkipGroup(tag uint32) (err error) {
224	tagStack := make([]uint32, 0, 16)
225	tagStack = append(tagStack, tag)
226	var n uint32
227	for len(tagStack) > 0 {
228		tag, err = b.DecodeVarint32()
229		if err != nil {
230			return err
231		}
232		switch protowire.Type(tag & 0x7) {
233		case protowire.VarintType:
234			err = b.SkipVarint()
235		case protowire.Fixed64Type:
236			err = b.Skip(8)
237		case protowire.BytesType:
238			n, err = b.DecodeVarint32()
239			if err == nil {
240				err = b.Skip(int(n))
241			}
242		case protowire.StartGroupType:
243			tagStack = append(tagStack, tag)
244		case protowire.Fixed32Type:
245			err = b.SkipFixed32()
246		case protowire.EndGroupType:
247			if protoFieldNumber(tagStack[len(tagStack)-1]) == protoFieldNumber(tag) {
248				tagStack = tagStack[:len(tagStack)-1]
249			} else {
250				err = fmt.Errorf("end group tag %d does not match begin group tag %d at pos %d",
251					protoFieldNumber(tag), protoFieldNumber(tagStack[len(tagStack)-1]), b.Pos)
252			}
253		}
254		if err != nil {
255			return err
256		}
257	}
258	return nil
259}
260
261// skipVarint effiently skips a varint
262func (b *BufferReader) SkipVarint() (err error) {
263	i := b.Pos
264
265	if len(b.Buf)-i < 10 {
266		// Use DecodeVarintSlow() to check for buffer overflow, but ignore result
267		if _, err := b.DecodeVarintSlow(); err != nil {
268			return err
269		}
270		return nil
271	}
272
273	if b.Buf[i] < 0x80 {
274		goto out
275	}
276	i++
277
278	if b.Buf[i] < 0x80 {
279		goto out
280	}
281	i++
282
283	if b.Buf[i] < 0x80 {
284		goto out
285	}
286	i++
287
288	if b.Buf[i] < 0x80 {
289		goto out
290	}
291	i++
292
293	if b.Buf[i] < 0x80 {
294		goto out
295	}
296	i++
297
298	if b.Buf[i] < 0x80 {
299		goto out
300	}
301	i++
302
303	if b.Buf[i] < 0x80 {
304		goto out
305	}
306	i++
307
308	if b.Buf[i] < 0x80 {
309		goto out
310	}
311	i++
312
313	if b.Buf[i] < 0x80 {
314		goto out
315	}
316	i++
317
318	if b.Buf[i] < 0x80 {
319		goto out
320	}
321	return errOverflow
322
323out:
324	b.Pos = i + 1
325	return nil
326}
327
328// skip skips the specified number of bytes
329func (b *BufferReader) Skip(n int) (err error) {
330	if len(b.Buf) < b.Pos+n {
331		return io.ErrUnexpectedEOF
332	}
333	b.Pos += n
334	return
335}
336
337// skipFixed64 skips a fixed64
338func (b *BufferReader) SkipFixed64() (err error) {
339	return b.Skip(8)
340}
341
342// skipFixed32 skips a fixed32
343func (b *BufferReader) SkipFixed32() (err error) {
344	return b.Skip(4)
345}
346
347// skipBytes skips a set of bytes
348func (b *BufferReader) SkipBytes() (err error) {
349	n, err := b.DecodeVarint32()
350	if err != nil {
351		return err
352	}
353	return b.Skip(int(n))
354}
355
356// Done returns whether we are at the end of the protobuf
357func (b *BufferReader) Done() bool {
358	return b.Pos == len(b.Buf)
359}
360
361// Remaining returns how many bytes remain
362func (b *BufferReader) Remaining() int {
363	return len(b.Buf) - b.Pos
364}