message_opaque.go

  1// Copyright 2024 The Go Authors. All rights reserved.
  2// Use of this source code is governed by a BSD-style
  3// license that can be found in the LICENSE file.
  4
  5package impl
  6
  7import (
  8	"fmt"
  9	"math"
 10	"reflect"
 11	"strings"
 12	"sync/atomic"
 13
 14	"google.golang.org/protobuf/reflect/protoreflect"
 15)
 16
// opaqueStructInfo is the struct layout information used when initializing a
// message whose Go struct was generated with the Opaque API.
//
// It embeds structInfo so its fields (fieldsByNumber, oneofsByName,
// oneofWrappersByNumber, ...) can be accessed directly on an
// opaqueStructInfo value.
type opaqueStructInfo struct {
	structInfo
}
 20
 21// isOpaque determines whether a protobuf message type is on the Opaque API.  It
 22// checks whether the type is a Go struct that protoc-gen-go would generate.
 23//
 24// This function only detects newly generated messages from the v2
 25// implementation of protoc-gen-go. It is unable to classify generated messages
 26// that are too old or those that are generated by a different generator
 27// such as protoc-gen-gogo.
 28func isOpaque(t reflect.Type) bool {
 29	// The current detection mechanism is to simply check the first field
 30	// for a struct tag with the "protogen" key.
 31	if t.Kind() == reflect.Struct && t.NumField() > 0 {
 32		pgt := t.Field(0).Tag.Get("protogen")
 33		return strings.HasPrefix(pgt, "opaque.")
 34	}
 35	return false
 36}
 37
 38func opaqueInitHook(mi *MessageInfo) bool {
 39	mt := mi.GoReflectType.Elem()
 40	si := opaqueStructInfo{
 41		structInfo: mi.makeStructInfo(mt),
 42	}
 43
 44	if !isOpaque(mt) {
 45		return false
 46	}
 47
 48	defer atomic.StoreUint32(&mi.initDone, 1)
 49
 50	mi.fields = map[protoreflect.FieldNumber]*fieldInfo{}
 51	fds := mi.Desc.Fields()
 52	for i := 0; i < fds.Len(); i++ {
 53		fd := fds.Get(i)
 54		fs := si.fieldsByNumber[fd.Number()]
 55		var fi fieldInfo
 56		usePresence, _ := usePresenceForField(si, fd)
 57
 58		switch {
 59		case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
 60			// Oneofs are no different for opaque.
 61			fi = fieldInfoForOneof(fd, si.oneofsByName[fd.ContainingOneof().Name()], mi.Exporter, si.oneofWrappersByNumber[fd.Number()])
 62		case fd.IsMap():
 63			fi = mi.fieldInfoForMapOpaque(si, fd, fs)
 64		case fd.IsList() && fd.Message() == nil && usePresence:
 65			fi = mi.fieldInfoForScalarListOpaque(si, fd, fs)
 66		case fd.IsList() && fd.Message() == nil:
 67			// Proto3 lists without presence can use same access methods as open
 68			fi = fieldInfoForList(fd, fs, mi.Exporter)
 69		case fd.IsList() && usePresence:
 70			fi = mi.fieldInfoForMessageListOpaque(si, fd, fs)
 71		case fd.IsList():
 72			// Proto3 opaque messages that does not need presence bitmap.
 73			// Different representation than open struct, but same logic
 74			fi = mi.fieldInfoForMessageListOpaqueNoPresence(si, fd, fs)
 75		case fd.Message() != nil && usePresence:
 76			fi = mi.fieldInfoForMessageOpaque(si, fd, fs)
 77		case fd.Message() != nil:
 78			// Proto3 messages without presence can use same access methods as open
 79			fi = fieldInfoForMessage(fd, fs, mi.Exporter)
 80		default:
 81			fi = mi.fieldInfoForScalarOpaque(si, fd, fs)
 82		}
 83		mi.fields[fd.Number()] = &fi
 84	}
 85	mi.oneofs = map[protoreflect.Name]*oneofInfo{}
 86	for i := 0; i < mi.Desc.Oneofs().Len(); i++ {
 87		od := mi.Desc.Oneofs().Get(i)
 88		mi.oneofs[od.Name()] = makeOneofInfoOpaque(mi, od, si.structInfo, mi.Exporter)
 89	}
 90
 91	mi.denseFields = make([]*fieldInfo, fds.Len()*2)
 92	for i := 0; i < fds.Len(); i++ {
 93		if fd := fds.Get(i); int(fd.Number()) < len(mi.denseFields) {
 94			mi.denseFields[fd.Number()] = mi.fields[fd.Number()]
 95		}
 96	}
 97
 98	for i := 0; i < fds.Len(); {
 99		fd := fds.Get(i)
100		if od := fd.ContainingOneof(); od != nil && !fd.ContainingOneof().IsSynthetic() {
101			mi.rangeInfos = append(mi.rangeInfos, mi.oneofs[od.Name()])
102			i += od.Fields().Len()
103		} else {
104			mi.rangeInfos = append(mi.rangeInfos, mi.fields[fd.Number()])
105			i++
106		}
107	}
108
109	mi.makeExtensionFieldsFunc(mt, si.structInfo)
110	mi.makeUnknownFieldsFunc(mt, si.structInfo)
111	mi.makeOpaqueCoderMethods(mt, si)
112	mi.makeFieldTypes(si.structInfo)
113
114	return true
115}
116
117func makeOneofInfoOpaque(mi *MessageInfo, od protoreflect.OneofDescriptor, si structInfo, x exporter) *oneofInfo {
118	oi := &oneofInfo{oneofDesc: od}
119	if od.IsSynthetic() {
120		fd := od.Fields().Get(0)
121		index, _ := presenceIndex(mi.Desc, fd)
122		oi.which = func(p pointer) protoreflect.FieldNumber {
123			if p.IsNil() {
124				return 0
125			}
126			if !mi.present(p, index) {
127				return 0
128			}
129			return od.Fields().Get(0).Number()
130		}
131		return oi
132	}
133	// Dispatch to non-opaque oneof implementation for non-synthetic oneofs.
134	return makeOneofInfo(od, si, x)
135}
136
137func (mi *MessageInfo) fieldInfoForMapOpaque(si opaqueStructInfo, fd protoreflect.FieldDescriptor, fs reflect.StructField) fieldInfo {
138	ft := fs.Type
139	if ft.Kind() != reflect.Map {
140		panic(fmt.Sprintf("invalid type: got %v, want map kind", ft))
141	}
142	fieldOffset := offsetOf(fs)
143	conv := NewConverter(ft, fd)
144	return fieldInfo{
145		fieldDesc: fd,
146		has: func(p pointer) bool {
147			if p.IsNil() {
148				return false
149			}
150			// Don't bother checking presence bits, since we need to
151			// look at the map length even if the presence bit is set.
152			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
153			return rv.Len() > 0
154		},
155		clear: func(p pointer) {
156			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
157			rv.Set(reflect.Zero(rv.Type()))
158		},
159		get: func(p pointer) protoreflect.Value {
160			if p.IsNil() {
161				return conv.Zero()
162			}
163			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
164			if rv.Len() == 0 {
165				return conv.Zero()
166			}
167			return conv.PBValueOf(rv)
168		},
169		set: func(p pointer, v protoreflect.Value) {
170			pv := conv.GoValueOf(v)
171			if pv.IsNil() {
172				panic(fmt.Sprintf("invalid value: setting map field to read-only value"))
173			}
174			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
175			rv.Set(pv)
176		},
177		mutable: func(p pointer) protoreflect.Value {
178			v := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
179			if v.IsNil() {
180				v.Set(reflect.MakeMap(fs.Type))
181			}
182			return conv.PBValueOf(v)
183		},
184		newField: func() protoreflect.Value {
185			return conv.New()
186		},
187	}
188}
189
190func (mi *MessageInfo) fieldInfoForScalarListOpaque(si opaqueStructInfo, fd protoreflect.FieldDescriptor, fs reflect.StructField) fieldInfo {
191	ft := fs.Type
192	if ft.Kind() != reflect.Slice {
193		panic(fmt.Sprintf("invalid type: got %v, want slice kind", ft))
194	}
195	conv := NewConverter(reflect.PtrTo(ft), fd)
196	fieldOffset := offsetOf(fs)
197	index, _ := presenceIndex(mi.Desc, fd)
198	return fieldInfo{
199		fieldDesc: fd,
200		has: func(p pointer) bool {
201			if p.IsNil() {
202				return false
203			}
204			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
205			return rv.Len() > 0
206		},
207		clear: func(p pointer) {
208			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
209			rv.Set(reflect.Zero(rv.Type()))
210		},
211		get: func(p pointer) protoreflect.Value {
212			if p.IsNil() {
213				return conv.Zero()
214			}
215			rv := p.Apply(fieldOffset).AsValueOf(fs.Type)
216			if rv.Elem().Len() == 0 {
217				return conv.Zero()
218			}
219			return conv.PBValueOf(rv)
220		},
221		set: func(p pointer, v protoreflect.Value) {
222			pv := conv.GoValueOf(v)
223			if pv.IsNil() {
224				panic(fmt.Sprintf("invalid value: setting repeated field to read-only value"))
225			}
226			mi.setPresent(p, index)
227			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
228			rv.Set(pv.Elem())
229		},
230		mutable: func(p pointer) protoreflect.Value {
231			mi.setPresent(p, index)
232			return conv.PBValueOf(p.Apply(fieldOffset).AsValueOf(fs.Type))
233		},
234		newField: func() protoreflect.Value {
235			return conv.New()
236		},
237	}
238}
239
240func (mi *MessageInfo) fieldInfoForMessageListOpaque(si opaqueStructInfo, fd protoreflect.FieldDescriptor, fs reflect.StructField) fieldInfo {
241	ft := fs.Type
242	if ft.Kind() != reflect.Ptr || ft.Elem().Kind() != reflect.Slice {
243		panic(fmt.Sprintf("invalid type: got %v, want slice kind", ft))
244	}
245	conv := NewConverter(ft, fd)
246	fieldOffset := offsetOf(fs)
247	index, _ := presenceIndex(mi.Desc, fd)
248	fieldNumber := fd.Number()
249	return fieldInfo{
250		fieldDesc: fd,
251		has: func(p pointer) bool {
252			if p.IsNil() {
253				return false
254			}
255			if !mi.present(p, index) {
256				return false
257			}
258			sp := p.Apply(fieldOffset).AtomicGetPointer()
259			if sp.IsNil() {
260				// Lazily unmarshal this field.
261				mi.lazyUnmarshal(p, fieldNumber)
262				sp = p.Apply(fieldOffset).AtomicGetPointer()
263			}
264			rv := sp.AsValueOf(fs.Type.Elem())
265			return rv.Elem().Len() > 0
266		},
267		clear: func(p pointer) {
268			fp := p.Apply(fieldOffset)
269			sp := fp.AtomicGetPointer()
270			if sp.IsNil() {
271				sp = fp.AtomicSetPointerIfNil(pointerOfValue(reflect.New(fs.Type.Elem())))
272				mi.setPresent(p, index)
273			}
274			rv := sp.AsValueOf(fs.Type.Elem())
275			rv.Elem().Set(reflect.Zero(rv.Type().Elem()))
276		},
277		get: func(p pointer) protoreflect.Value {
278			if p.IsNil() {
279				return conv.Zero()
280			}
281			if !mi.present(p, index) {
282				return conv.Zero()
283			}
284			sp := p.Apply(fieldOffset).AtomicGetPointer()
285			if sp.IsNil() {
286				// Lazily unmarshal this field.
287				mi.lazyUnmarshal(p, fieldNumber)
288				sp = p.Apply(fieldOffset).AtomicGetPointer()
289			}
290			rv := sp.AsValueOf(fs.Type.Elem())
291			if rv.Elem().Len() == 0 {
292				return conv.Zero()
293			}
294			return conv.PBValueOf(rv)
295		},
296		set: func(p pointer, v protoreflect.Value) {
297			fp := p.Apply(fieldOffset)
298			sp := fp.AtomicGetPointer()
299			if sp.IsNil() {
300				sp = fp.AtomicSetPointerIfNil(pointerOfValue(reflect.New(fs.Type.Elem())))
301				mi.setPresent(p, index)
302			}
303			rv := sp.AsValueOf(fs.Type.Elem())
304			val := conv.GoValueOf(v)
305			if val.IsNil() {
306				panic(fmt.Sprintf("invalid value: setting repeated field to read-only value"))
307			} else {
308				rv.Elem().Set(val.Elem())
309			}
310		},
311		mutable: func(p pointer) protoreflect.Value {
312			fp := p.Apply(fieldOffset)
313			sp := fp.AtomicGetPointer()
314			if sp.IsNil() {
315				if mi.present(p, index) {
316					// Lazily unmarshal this field.
317					mi.lazyUnmarshal(p, fieldNumber)
318					sp = p.Apply(fieldOffset).AtomicGetPointer()
319				} else {
320					sp = fp.AtomicSetPointerIfNil(pointerOfValue(reflect.New(fs.Type.Elem())))
321					mi.setPresent(p, index)
322				}
323			}
324			rv := sp.AsValueOf(fs.Type.Elem())
325			return conv.PBValueOf(rv)
326		},
327		newField: func() protoreflect.Value {
328			return conv.New()
329		},
330	}
331}
332
333func (mi *MessageInfo) fieldInfoForMessageListOpaqueNoPresence(si opaqueStructInfo, fd protoreflect.FieldDescriptor, fs reflect.StructField) fieldInfo {
334	ft := fs.Type
335	if ft.Kind() != reflect.Ptr || ft.Elem().Kind() != reflect.Slice {
336		panic(fmt.Sprintf("invalid type: got %v, want slice kind", ft))
337	}
338	conv := NewConverter(ft, fd)
339	fieldOffset := offsetOf(fs)
340	return fieldInfo{
341		fieldDesc: fd,
342		has: func(p pointer) bool {
343			if p.IsNil() {
344				return false
345			}
346			sp := p.Apply(fieldOffset).AtomicGetPointer()
347			if sp.IsNil() {
348				return false
349			}
350			rv := sp.AsValueOf(fs.Type.Elem())
351			return rv.Elem().Len() > 0
352		},
353		clear: func(p pointer) {
354			sp := p.Apply(fieldOffset).AtomicGetPointer()
355			if !sp.IsNil() {
356				rv := sp.AsValueOf(fs.Type.Elem())
357				rv.Elem().Set(reflect.Zero(rv.Type().Elem()))
358			}
359		},
360		get: func(p pointer) protoreflect.Value {
361			if p.IsNil() {
362				return conv.Zero()
363			}
364			sp := p.Apply(fieldOffset).AtomicGetPointer()
365			if sp.IsNil() {
366				return conv.Zero()
367			}
368			rv := sp.AsValueOf(fs.Type.Elem())
369			if rv.Elem().Len() == 0 {
370				return conv.Zero()
371			}
372			return conv.PBValueOf(rv)
373		},
374		set: func(p pointer, v protoreflect.Value) {
375			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
376			if rv.IsNil() {
377				rv.Set(reflect.New(fs.Type.Elem()))
378			}
379			val := conv.GoValueOf(v)
380			if val.IsNil() {
381				panic(fmt.Sprintf("invalid value: setting repeated field to read-only value"))
382			} else {
383				rv.Elem().Set(val.Elem())
384			}
385		},
386		mutable: func(p pointer) protoreflect.Value {
387			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
388			if rv.IsNil() {
389				rv.Set(reflect.New(fs.Type.Elem()))
390			}
391			return conv.PBValueOf(rv)
392		},
393		newField: func() protoreflect.Value {
394			return conv.New()
395		},
396	}
397}
398
399func (mi *MessageInfo) fieldInfoForScalarOpaque(si opaqueStructInfo, fd protoreflect.FieldDescriptor, fs reflect.StructField) fieldInfo {
400	ft := fs.Type
401	nullable := fd.HasPresence()
402	if oneof := fd.ContainingOneof(); oneof != nil && oneof.IsSynthetic() {
403		nullable = true
404	}
405	deref := false
406	if nullable && ft.Kind() == reflect.Ptr {
407		ft = ft.Elem()
408		deref = true
409	}
410	conv := NewConverter(ft, fd)
411	fieldOffset := offsetOf(fs)
412	index, _ := presenceIndex(mi.Desc, fd)
413	var getter func(p pointer) protoreflect.Value
414	if !nullable {
415		getter = getterForDirectScalar(fd, fs, conv, fieldOffset)
416	} else {
417		getter = getterForOpaqueNullableScalar(mi, index, fd, fs, conv, fieldOffset)
418	}
419	return fieldInfo{
420		fieldDesc: fd,
421		has: func(p pointer) bool {
422			if p.IsNil() {
423				return false
424			}
425			if nullable {
426				return mi.present(p, index)
427			}
428			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
429			switch rv.Kind() {
430			case reflect.Bool:
431				return rv.Bool()
432			case reflect.Int32, reflect.Int64:
433				return rv.Int() != 0
434			case reflect.Uint32, reflect.Uint64:
435				return rv.Uint() != 0
436			case reflect.Float32, reflect.Float64:
437				return rv.Float() != 0 || math.Signbit(rv.Float())
438			case reflect.String, reflect.Slice:
439				return rv.Len() > 0
440			default:
441				panic(fmt.Sprintf("invalid type: %v", rv.Type())) // should never happen
442			}
443		},
444		clear: func(p pointer) {
445			if nullable {
446				mi.clearPresent(p, index)
447			}
448			// This is only valuable for bytes and strings, but we do it unconditionally.
449			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
450			rv.Set(reflect.Zero(rv.Type()))
451		},
452		get: getter,
453		// TODO: Implement unsafe fast path for set?
454		set: func(p pointer, v protoreflect.Value) {
455			rv := p.Apply(fieldOffset).AsValueOf(fs.Type).Elem()
456			if deref {
457				if rv.IsNil() {
458					rv.Set(reflect.New(ft))
459				}
460				rv = rv.Elem()
461			}
462
463			rv.Set(conv.GoValueOf(v))
464			if nullable && rv.Kind() == reflect.Slice && rv.IsNil() {
465				rv.Set(emptyBytes)
466			}
467			if nullable {
468				mi.setPresent(p, index)
469			}
470		},
471		newField: func() protoreflect.Value {
472			return conv.New()
473		},
474	}
475}
476
477func (mi *MessageInfo) fieldInfoForMessageOpaque(si opaqueStructInfo, fd protoreflect.FieldDescriptor, fs reflect.StructField) fieldInfo {
478	ft := fs.Type
479	conv := NewConverter(ft, fd)
480	fieldOffset := offsetOf(fs)
481	index, _ := presenceIndex(mi.Desc, fd)
482	fieldNumber := fd.Number()
483	elemType := fs.Type.Elem()
484	return fieldInfo{
485		fieldDesc: fd,
486		has: func(p pointer) bool {
487			if p.IsNil() {
488				return false
489			}
490			return mi.present(p, index)
491		},
492		clear: func(p pointer) {
493			mi.clearPresent(p, index)
494			p.Apply(fieldOffset).AtomicSetNilPointer()
495		},
496		get: func(p pointer) protoreflect.Value {
497			if p.IsNil() || !mi.present(p, index) {
498				return conv.Zero()
499			}
500			fp := p.Apply(fieldOffset)
501			mp := fp.AtomicGetPointer()
502			if mp.IsNil() {
503				// Lazily unmarshal this field.
504				mi.lazyUnmarshal(p, fieldNumber)
505				mp = fp.AtomicGetPointer()
506			}
507			rv := mp.AsValueOf(elemType)
508			return conv.PBValueOf(rv)
509		},
510		set: func(p pointer, v protoreflect.Value) {
511			val := pointerOfValue(conv.GoValueOf(v))
512			if val.IsNil() {
513				panic("invalid nil pointer")
514			}
515			p.Apply(fieldOffset).AtomicSetPointer(val)
516			mi.setPresent(p, index)
517		},
518		mutable: func(p pointer) protoreflect.Value {
519			fp := p.Apply(fieldOffset)
520			mp := fp.AtomicGetPointer()
521			if mp.IsNil() {
522				if mi.present(p, index) {
523					// Lazily unmarshal this field.
524					mi.lazyUnmarshal(p, fieldNumber)
525					mp = fp.AtomicGetPointer()
526				} else {
527					mp = pointerOfValue(conv.GoValueOf(conv.New()))
528					fp.AtomicSetPointer(mp)
529					mi.setPresent(p, index)
530				}
531			}
532			return conv.PBValueOf(mp.AsValueOf(fs.Type.Elem()))
533		},
534		newMessage: func() protoreflect.Message {
535			return conv.New().Message()
536		},
537		newField: func() protoreflect.Value {
538			return conv.New()
539		},
540	}
541}
542
543// A presenceList wraps a List, updating presence bits as necessary when the
544// list contents change.
545type presenceList struct {
546	pvalueList
547	setPresence func(bool)
548}
549type pvalueList interface {
550	protoreflect.List
551	//Unwrapper
552}
553
554func (list presenceList) Append(v protoreflect.Value) {
555	list.pvalueList.Append(v)
556	list.setPresence(true)
557}
558func (list presenceList) Truncate(i int) {
559	list.pvalueList.Truncate(i)
560	list.setPresence(i > 0)
561}
562
563// presenceIndex returns the index to pass to presence functions.
564//
565// TODO: field.Desc.Index() would be simpler, and would give space to record the presence of oneof fields.
566func presenceIndex(md protoreflect.MessageDescriptor, fd protoreflect.FieldDescriptor) (uint32, presenceSize) {
567	found := false
568	var index, numIndices uint32
569	for i := 0; i < md.Fields().Len(); i++ {
570		f := md.Fields().Get(i)
571		if f == fd {
572			found = true
573			index = numIndices
574		}
575		if f.ContainingOneof() == nil || isLastOneofField(f) {
576			numIndices++
577		}
578	}
579	if !found {
580		panic(fmt.Sprintf("BUG: %v not in %v", fd.Name(), md.FullName()))
581	}
582	return index, presenceSize(numIndices)
583}
584
585func isLastOneofField(fd protoreflect.FieldDescriptor) bool {
586	fields := fd.ContainingOneof().Fields()
587	return fields.Get(fields.Len()-1) == fd
588}
589
590func (mi *MessageInfo) setPresent(p pointer, index uint32) {
591	p.Apply(mi.presenceOffset).PresenceInfo().SetPresent(index, mi.presenceSize)
592}
593
594func (mi *MessageInfo) clearPresent(p pointer, index uint32) {
595	p.Apply(mi.presenceOffset).PresenceInfo().ClearPresent(index)
596}
597
598func (mi *MessageInfo) present(p pointer, index uint32) bool {
599	return p.Apply(mi.presenceOffset).PresenceInfo().Present(index)
600}
601
602// usePresenceForField implements the somewhat intricate logic of when
603// the presence bitmap is used for a field.  The main logic is that a
604// field that is optional or that can be lazy will use the presence
605// bit, but for proto2, also maps have a presence bit. It also records
606// if the field can ever be lazy, which is true if we have a
607// lazyOffset and the field is a message or a slice of messages. A
608// field that is lazy will always need a presence bit.  Oneofs are not
609// lazy and do not use presence, unless they are a synthetic oneof,
610// which is a proto3 optional field. For proto3 optionals, we use the
611// presence and they can also be lazy when applicable (a message).
612func usePresenceForField(si opaqueStructInfo, fd protoreflect.FieldDescriptor) (usePresence, canBeLazy bool) {
613	hasLazyField := fd.(interface{ IsLazy() bool }).IsLazy()
614
615	// Non-oneof scalar fields with explicit field presence use the presence array.
616	usesPresenceArray := fd.HasPresence() && fd.Message() == nil && (fd.ContainingOneof() == nil || fd.ContainingOneof().IsSynthetic())
617	switch {
618	case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
619		return false, false
620	case fd.IsMap():
621		return false, false
622	case fd.Kind() == protoreflect.MessageKind || fd.Kind() == protoreflect.GroupKind:
623		return hasLazyField, hasLazyField
624	default:
625		return usesPresenceArray || (hasLazyField && fd.HasPresence()), false
626	}
627}