1// Copyright 2024 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "reflect"
10
11 "google.golang.org/protobuf/encoding/protowire"
12 "google.golang.org/protobuf/internal/errors"
13 "google.golang.org/protobuf/reflect/protoreflect"
14)
15
16func makeOpaqueMessageFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) (*MessageInfo, pointerCoderFuncs) {
17 mi := getMessageInfo(ft)
18 if mi == nil {
19 panic(fmt.Sprintf("invalid field: %v: unsupported message type %v", fd.FullName(), ft))
20 }
21 switch fd.Kind() {
22 case protoreflect.MessageKind:
23 return mi, pointerCoderFuncs{
24 size: sizeOpaqueMessage,
25 marshal: appendOpaqueMessage,
26 unmarshal: consumeOpaqueMessage,
27 isInit: isInitOpaqueMessage,
28 merge: mergeOpaqueMessage,
29 }
30 case protoreflect.GroupKind:
31 return mi, pointerCoderFuncs{
32 size: sizeOpaqueGroup,
33 marshal: appendOpaqueGroup,
34 unmarshal: consumeOpaqueGroup,
35 isInit: isInitOpaqueMessage,
36 merge: mergeOpaqueMessage,
37 }
38 }
39 panic("unexpected field kind")
40}
41
42func sizeOpaqueMessage(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
43 return protowire.SizeBytes(f.mi.sizePointer(p.AtomicGetPointer(), opts)) + f.tagsize
44}
45
46func appendOpaqueMessage(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
47 mp := p.AtomicGetPointer()
48 calculatedSize := f.mi.sizePointer(mp, opts)
49 b = protowire.AppendVarint(b, f.wiretag)
50 b = protowire.AppendVarint(b, uint64(calculatedSize))
51 before := len(b)
52 b, err := f.mi.marshalAppendPointer(b, mp, opts)
53 if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil {
54 return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
55 }
56 return b, err
57}
58
59func consumeOpaqueMessage(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
60 if wtyp != protowire.BytesType {
61 return out, errUnknown
62 }
63 v, n := protowire.ConsumeBytes(b)
64 if n < 0 {
65 return out, errDecode
66 }
67 mp := p.AtomicGetPointer()
68 if mp.IsNil() {
69 mp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
70 }
71 o, err := f.mi.unmarshalPointer(v, mp, 0, opts)
72 if err != nil {
73 return out, err
74 }
75 out.n = n
76 out.initialized = o.initialized
77 return out, nil
78}
79
80func isInitOpaqueMessage(p pointer, f *coderFieldInfo) error {
81 mp := p.AtomicGetPointer()
82 if mp.IsNil() {
83 return nil
84 }
85 return f.mi.checkInitializedPointer(mp)
86}
87
88func mergeOpaqueMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
89 dstmp := dst.AtomicGetPointer()
90 if dstmp.IsNil() {
91 dstmp = dst.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
92 }
93 f.mi.mergePointer(dstmp, src.AtomicGetPointer(), opts)
94}
95
96func sizeOpaqueGroup(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
97 return 2*f.tagsize + f.mi.sizePointer(p.AtomicGetPointer(), opts)
98}
99
100func appendOpaqueGroup(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
101 b = protowire.AppendVarint(b, f.wiretag) // start group
102 b, err := f.mi.marshalAppendPointer(b, p.AtomicGetPointer(), opts)
103 b = protowire.AppendVarint(b, f.wiretag+1) // end group
104 return b, err
105}
106
107func consumeOpaqueGroup(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
108 if wtyp != protowire.StartGroupType {
109 return out, errUnknown
110 }
111 mp := p.AtomicGetPointer()
112 if mp.IsNil() {
113 mp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
114 }
115 o, e := f.mi.unmarshalPointer(b, mp, f.num, opts)
116 return o, e
117}
118
119func makeOpaqueRepeatedMessageFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) (*MessageInfo, pointerCoderFuncs) {
120 if ft.Kind() != reflect.Ptr || ft.Elem().Kind() != reflect.Slice {
121 panic(fmt.Sprintf("invalid field: %v: unsupported type for opaque repeated message: %v", fd.FullName(), ft))
122 }
123 mt := ft.Elem().Elem() // *[]*T -> *T
124 mi := getMessageInfo(mt)
125 if mi == nil {
126 panic(fmt.Sprintf("invalid field: %v: unsupported message type %v", fd.FullName(), mt))
127 }
128 switch fd.Kind() {
129 case protoreflect.MessageKind:
130 return mi, pointerCoderFuncs{
131 size: sizeOpaqueMessageSlice,
132 marshal: appendOpaqueMessageSlice,
133 unmarshal: consumeOpaqueMessageSlice,
134 isInit: isInitOpaqueMessageSlice,
135 merge: mergeOpaqueMessageSlice,
136 }
137 case protoreflect.GroupKind:
138 return mi, pointerCoderFuncs{
139 size: sizeOpaqueGroupSlice,
140 marshal: appendOpaqueGroupSlice,
141 unmarshal: consumeOpaqueGroupSlice,
142 isInit: isInitOpaqueMessageSlice,
143 merge: mergeOpaqueMessageSlice,
144 }
145 }
146 panic("unexpected field kind")
147}
148
149func sizeOpaqueMessageSlice(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
150 s := p.AtomicGetPointer().PointerSlice()
151 n := 0
152 for _, v := range s {
153 n += protowire.SizeBytes(f.mi.sizePointer(v, opts)) + f.tagsize
154 }
155 return n
156}
157
158func appendOpaqueMessageSlice(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
159 s := p.AtomicGetPointer().PointerSlice()
160 var err error
161 for _, v := range s {
162 b = protowire.AppendVarint(b, f.wiretag)
163 siz := f.mi.sizePointer(v, opts)
164 b = protowire.AppendVarint(b, uint64(siz))
165 before := len(b)
166 b, err = f.mi.marshalAppendPointer(b, v, opts)
167 if err != nil {
168 return b, err
169 }
170 if measuredSize := len(b) - before; siz != measuredSize {
171 return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
172 }
173 }
174 return b, nil
175}
176
177func consumeOpaqueMessageSlice(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
178 if wtyp != protowire.BytesType {
179 return out, errUnknown
180 }
181 v, n := protowire.ConsumeBytes(b)
182 if n < 0 {
183 return out, errDecode
184 }
185 mp := pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))
186 o, err := f.mi.unmarshalPointer(v, mp, 0, opts)
187 if err != nil {
188 return out, err
189 }
190 sp := p.AtomicGetPointer()
191 if sp.IsNil() {
192 sp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.ft.Elem())))
193 }
194 sp.AppendPointerSlice(mp)
195 out.n = n
196 out.initialized = o.initialized
197 return out, nil
198}
199
200func isInitOpaqueMessageSlice(p pointer, f *coderFieldInfo) error {
201 sp := p.AtomicGetPointer()
202 if sp.IsNil() {
203 return nil
204 }
205 s := sp.PointerSlice()
206 for _, v := range s {
207 if err := f.mi.checkInitializedPointer(v); err != nil {
208 return err
209 }
210 }
211 return nil
212}
213
214func mergeOpaqueMessageSlice(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
215 ds := dst.AtomicGetPointer()
216 if ds.IsNil() {
217 ds = dst.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.ft.Elem())))
218 }
219 for _, sp := range src.AtomicGetPointer().PointerSlice() {
220 dm := pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))
221 f.mi.mergePointer(dm, sp, opts)
222 ds.AppendPointerSlice(dm)
223 }
224}
225
226func sizeOpaqueGroupSlice(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
227 s := p.AtomicGetPointer().PointerSlice()
228 n := 0
229 for _, v := range s {
230 n += 2*f.tagsize + f.mi.sizePointer(v, opts)
231 }
232 return n
233}
234
235func appendOpaqueGroupSlice(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
236 s := p.AtomicGetPointer().PointerSlice()
237 var err error
238 for _, v := range s {
239 b = protowire.AppendVarint(b, f.wiretag) // start group
240 b, err = f.mi.marshalAppendPointer(b, v, opts)
241 if err != nil {
242 return b, err
243 }
244 b = protowire.AppendVarint(b, f.wiretag+1) // end group
245 }
246 return b, nil
247}
248
249func consumeOpaqueGroupSlice(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
250 if wtyp != protowire.StartGroupType {
251 return out, errUnknown
252 }
253 mp := pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))
254 out, err = f.mi.unmarshalPointer(b, mp, f.num, opts)
255 if err != nil {
256 return out, err
257 }
258 sp := p.AtomicGetPointer()
259 if sp.IsNil() {
260 sp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.ft.Elem())))
261 }
262 sp.AppendPointerSlice(mp)
263 return out, err
264}