codec_field_opaque.go

  1// Copyright 2024 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package impl
  6
  7import (
  8	"fmt"
  9	"reflect"
 10
 11	"google.golang.org/protobuf/encoding/protowire"
 12	"google.golang.org/protobuf/internal/errors"
 13	"google.golang.org/protobuf/reflect/protoreflect"
 14)
 15
 16func makeOpaqueMessageFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) (*MessageInfo, pointerCoderFuncs) {
 17	mi := getMessageInfo(ft)
 18	if mi == nil {
 19		panic(fmt.Sprintf("invalid field: %v: unsupported message type %v", fd.FullName(), ft))
 20	}
 21	switch fd.Kind() {
 22	case protoreflect.MessageKind:
 23		return mi, pointerCoderFuncs{
 24			size:      sizeOpaqueMessage,
 25			marshal:   appendOpaqueMessage,
 26			unmarshal: consumeOpaqueMessage,
 27			isInit:    isInitOpaqueMessage,
 28			merge:     mergeOpaqueMessage,
 29		}
 30	case protoreflect.GroupKind:
 31		return mi, pointerCoderFuncs{
 32			size:      sizeOpaqueGroup,
 33			marshal:   appendOpaqueGroup,
 34			unmarshal: consumeOpaqueGroup,
 35			isInit:    isInitOpaqueMessage,
 36			merge:     mergeOpaqueMessage,
 37		}
 38	}
 39	panic("unexpected field kind")
 40}
 41
 42func sizeOpaqueMessage(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
 43	return protowire.SizeBytes(f.mi.sizePointer(p.AtomicGetPointer(), opts)) + f.tagsize
 44}
 45
 46func appendOpaqueMessage(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
 47	mp := p.AtomicGetPointer()
 48	calculatedSize := f.mi.sizePointer(mp, opts)
 49	b = protowire.AppendVarint(b, f.wiretag)
 50	b = protowire.AppendVarint(b, uint64(calculatedSize))
 51	before := len(b)
 52	b, err := f.mi.marshalAppendPointer(b, mp, opts)
 53	if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil {
 54		return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
 55	}
 56	return b, err
 57}
 58
 59func consumeOpaqueMessage(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
 60	if wtyp != protowire.BytesType {
 61		return out, errUnknown
 62	}
 63	v, n := protowire.ConsumeBytes(b)
 64	if n < 0 {
 65		return out, errDecode
 66	}
 67	mp := p.AtomicGetPointer()
 68	if mp.IsNil() {
 69		mp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
 70	}
 71	o, err := f.mi.unmarshalPointer(v, mp, 0, opts)
 72	if err != nil {
 73		return out, err
 74	}
 75	out.n = n
 76	out.initialized = o.initialized
 77	return out, nil
 78}
 79
 80func isInitOpaqueMessage(p pointer, f *coderFieldInfo) error {
 81	mp := p.AtomicGetPointer()
 82	if mp.IsNil() {
 83		return nil
 84	}
 85	return f.mi.checkInitializedPointer(mp)
 86}
 87
 88func mergeOpaqueMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
 89	dstmp := dst.AtomicGetPointer()
 90	if dstmp.IsNil() {
 91		dstmp = dst.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
 92	}
 93	f.mi.mergePointer(dstmp, src.AtomicGetPointer(), opts)
 94}
 95
 96func sizeOpaqueGroup(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
 97	return 2*f.tagsize + f.mi.sizePointer(p.AtomicGetPointer(), opts)
 98}
 99
100func appendOpaqueGroup(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
101	b = protowire.AppendVarint(b, f.wiretag) // start group
102	b, err := f.mi.marshalAppendPointer(b, p.AtomicGetPointer(), opts)
103	b = protowire.AppendVarint(b, f.wiretag+1) // end group
104	return b, err
105}
106
107func consumeOpaqueGroup(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
108	if wtyp != protowire.StartGroupType {
109		return out, errUnknown
110	}
111	mp := p.AtomicGetPointer()
112	if mp.IsNil() {
113		mp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
114	}
115	o, e := f.mi.unmarshalPointer(b, mp, f.num, opts)
116	return o, e
117}
118
119func makeOpaqueRepeatedMessageFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) (*MessageInfo, pointerCoderFuncs) {
120	if ft.Kind() != reflect.Ptr || ft.Elem().Kind() != reflect.Slice {
121		panic(fmt.Sprintf("invalid field: %v: unsupported type for opaque repeated message: %v", fd.FullName(), ft))
122	}
123	mt := ft.Elem().Elem() // *[]*T -> *T
124	mi := getMessageInfo(mt)
125	if mi == nil {
126		panic(fmt.Sprintf("invalid field: %v: unsupported message type %v", fd.FullName(), mt))
127	}
128	switch fd.Kind() {
129	case protoreflect.MessageKind:
130		return mi, pointerCoderFuncs{
131			size:      sizeOpaqueMessageSlice,
132			marshal:   appendOpaqueMessageSlice,
133			unmarshal: consumeOpaqueMessageSlice,
134			isInit:    isInitOpaqueMessageSlice,
135			merge:     mergeOpaqueMessageSlice,
136		}
137	case protoreflect.GroupKind:
138		return mi, pointerCoderFuncs{
139			size:      sizeOpaqueGroupSlice,
140			marshal:   appendOpaqueGroupSlice,
141			unmarshal: consumeOpaqueGroupSlice,
142			isInit:    isInitOpaqueMessageSlice,
143			merge:     mergeOpaqueMessageSlice,
144		}
145	}
146	panic("unexpected field kind")
147}
148
149func sizeOpaqueMessageSlice(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
150	s := p.AtomicGetPointer().PointerSlice()
151	n := 0
152	for _, v := range s {
153		n += protowire.SizeBytes(f.mi.sizePointer(v, opts)) + f.tagsize
154	}
155	return n
156}
157
158func appendOpaqueMessageSlice(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
159	s := p.AtomicGetPointer().PointerSlice()
160	var err error
161	for _, v := range s {
162		b = protowire.AppendVarint(b, f.wiretag)
163		siz := f.mi.sizePointer(v, opts)
164		b = protowire.AppendVarint(b, uint64(siz))
165		before := len(b)
166		b, err = f.mi.marshalAppendPointer(b, v, opts)
167		if err != nil {
168			return b, err
169		}
170		if measuredSize := len(b) - before; siz != measuredSize {
171			return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
172		}
173	}
174	return b, nil
175}
176
177func consumeOpaqueMessageSlice(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
178	if wtyp != protowire.BytesType {
179		return out, errUnknown
180	}
181	v, n := protowire.ConsumeBytes(b)
182	if n < 0 {
183		return out, errDecode
184	}
185	mp := pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))
186	o, err := f.mi.unmarshalPointer(v, mp, 0, opts)
187	if err != nil {
188		return out, err
189	}
190	sp := p.AtomicGetPointer()
191	if sp.IsNil() {
192		sp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.ft.Elem())))
193	}
194	sp.AppendPointerSlice(mp)
195	out.n = n
196	out.initialized = o.initialized
197	return out, nil
198}
199
200func isInitOpaqueMessageSlice(p pointer, f *coderFieldInfo) error {
201	sp := p.AtomicGetPointer()
202	if sp.IsNil() {
203		return nil
204	}
205	s := sp.PointerSlice()
206	for _, v := range s {
207		if err := f.mi.checkInitializedPointer(v); err != nil {
208			return err
209		}
210	}
211	return nil
212}
213
214func mergeOpaqueMessageSlice(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
215	ds := dst.AtomicGetPointer()
216	if ds.IsNil() {
217		ds = dst.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.ft.Elem())))
218	}
219	for _, sp := range src.AtomicGetPointer().PointerSlice() {
220		dm := pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))
221		f.mi.mergePointer(dm, sp, opts)
222		ds.AppendPointerSlice(dm)
223	}
224}
225
226func sizeOpaqueGroupSlice(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
227	s := p.AtomicGetPointer().PointerSlice()
228	n := 0
229	for _, v := range s {
230		n += 2*f.tagsize + f.mi.sizePointer(v, opts)
231	}
232	return n
233}
234
235func appendOpaqueGroupSlice(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
236	s := p.AtomicGetPointer().PointerSlice()
237	var err error
238	for _, v := range s {
239		b = protowire.AppendVarint(b, f.wiretag) // start group
240		b, err = f.mi.marshalAppendPointer(b, v, opts)
241		if err != nil {
242			return b, err
243		}
244		b = protowire.AppendVarint(b, f.wiretag+1) // end group
245	}
246	return b, nil
247}
248
249func consumeOpaqueGroupSlice(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
250	if wtyp != protowire.StartGroupType {
251		return out, errUnknown
252	}
253	mp := pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))
254	out, err = f.mi.unmarshalPointer(b, mp, f.num, opts)
255	if err != nil {
256		return out, err
257	}
258	sp := p.AtomicGetPointer()
259	if sp.IsNil() {
260		sp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.ft.Elem())))
261	}
262	sp.AppendPointerSlice(mp)
263	return out, err
264}