blob: 0766b4b5b3798b86c045cdcdc4195495c54c3436 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package prototype provides constructors for protoreflect.EnumType,
// protoreflect.MessageType, and protoreflect.ExtensionType.
package prototype
import (
"fmt"
"reflect"
"sync"
"google.golang.org/protobuf/internal/descfmt"
"google.golang.org/protobuf/internal/value"
"google.golang.org/protobuf/reflect/protoreflect"
)
// Enum is a protoreflect.EnumType which combines a
// protoreflect.EnumDescriptor with a constructor function.
//
// Both EnumDescriptor and NewEnum must be populated.
// Once constructed, the exported fields must not be modified.
type Enum struct {
protoreflect.EnumDescriptor
// NewEnum constructs a new protoreflect.Enum representing the provided
// enum number. The returned Go type must be identical for every call.
NewEnum func(protoreflect.EnumNumber) protoreflect.Enum
once sync.Once
goType reflect.Type
}
func (t *Enum) New(n protoreflect.EnumNumber) protoreflect.Enum {
e := t.NewEnum(n)
t.once.Do(func() {
t.goType = reflect.TypeOf(e)
if e.Descriptor() != t.Descriptor() {
panic(fmt.Sprintf("mismatching enum descriptor: got %v, want %v", e.Descriptor(), t.Descriptor()))
}
if e.Descriptor().IsPlaceholder() {
panic("enum descriptor must not be a placeholder")
}
})
if t.goType != reflect.TypeOf(e) {
panic(fmt.Sprintf("mismatching types for enum: got %T, want %v", e, t.goType))
}
return e
}
func (t *Enum) GoType() reflect.Type {
t.New(0) // initialize t.typ
return t.goType
}
func (t *Enum) Descriptor() protoreflect.EnumDescriptor {
return t.EnumDescriptor
}
func (t *Enum) Format(s fmt.State, r rune) {
descfmt.FormatDesc(s, r, t)
}
// Message is a protoreflect.MessageType which combines a
// protoreflect.MessageDescriptor with a constructor function.
//
// Both MessageDescriptor and NewMessage must be populated.
// Once constructed, the exported fields must not be modified.
type Message struct {
protoreflect.MessageDescriptor
// NewMessage constructs an empty, newly allocated protoreflect.Message.
// The returned Go type must be identical for every call.
NewMessage func() protoreflect.Message
once sync.Once
goType reflect.Type
}
func (t *Message) New() protoreflect.Message {
m := t.NewMessage()
mi := m.Interface()
t.once.Do(func() {
t.goType = reflect.TypeOf(mi)
if m.Descriptor() != t.Descriptor() {
panic(fmt.Sprintf("mismatching message descriptor: got %v, want %v", m.Descriptor(), t.Descriptor()))
}
if m.Descriptor().IsPlaceholder() {
panic("message descriptor must not be a placeholder")
}
})
if t.goType != reflect.TypeOf(mi) {
panic(fmt.Sprintf("mismatching types for message: got %T, want %v", mi, t.goType))
}
return m
}
func (t *Message) GoType() reflect.Type {
t.New() // initialize t.goType
return t.goType
}
func (t *Message) Descriptor() protoreflect.MessageDescriptor {
return t.MessageDescriptor
}
func (t *Message) Format(s fmt.State, r rune) {
descfmt.FormatDesc(s, r, t)
}
// Extension is a protoreflect.ExtensionType which combines a
// protoreflect.ExtensionDescriptor with a constructor function.
//
// ExtensionDescriptor must be populated, while NewEnum or NewMessage must
// populated depending on the kind of the extension field.
// Once constructed, the exported fields must not be modified.
type Extension struct {
protoreflect.ExtensionDescriptor
// NewEnum constructs a new enum (see Enum.NewEnum).
// This must be populated if and only if ExtensionDescriptor.Kind
// is a protoreflect.EnumKind.
NewEnum func(protoreflect.EnumNumber) protoreflect.Enum
// NewMessage constructs a new message (see Enum.NewMessage).
// This must be populated if and only if ExtensionDescriptor.Kind
// is a protoreflect.MessageKind or protoreflect.GroupKind.
NewMessage func() protoreflect.Message
// TODO: Allow users to manually set new, valueOf, and interfaceOf.
// This allows users to implement custom composite types (e.g., List) or
// custom Go types for primitives (e.g., int32).
once sync.Once
goType reflect.Type
new func() protoreflect.Value
valueOf func(v interface{}) protoreflect.Value
interfaceOf func(v protoreflect.Value) interface{}
}
func (t *Extension) New() protoreflect.Value {
t.lazyInit()
pv := t.new()
v := t.interfaceOf(pv)
if reflect.TypeOf(v) != t.goType {
panic(fmt.Sprintf("invalid type: got %T, want %v", v, t.goType))
}
return pv
}
func (t *Extension) ValueOf(v interface{}) protoreflect.Value {
t.lazyInit()
if reflect.TypeOf(v) != t.goType {
panic(fmt.Sprintf("invalid type: got %T, want %v", v, t.goType))
}
return t.valueOf(v)
}
func (t *Extension) InterfaceOf(v protoreflect.Value) interface{} {
t.lazyInit()
vi := t.interfaceOf(v)
if reflect.TypeOf(vi) != t.goType {
panic(fmt.Sprintf("invalid type: got %T, want %v", vi, t.goType))
}
return vi
}
// GoType is the type of the extension field.
// The type is T for scalars and *[]T for lists (maps are not allowed).
// The type T is determined as follows:
//
// +------------+-------------------------------------+
// | Go type | Protobuf kind |
// +------------+-------------------------------------+
// | bool | BoolKind |
// | int32 | Int32Kind, Sint32Kind, Sfixed32Kind |
// | int64 | Int64Kind, Sint64Kind, Sfixed64Kind |
// | uint32 | Uint32Kind, Fixed32Kind |
// | uint64 | Uint64Kind, Fixed64Kind |
// | float32 | FloatKind |
// | float64 | DoubleKind |
// | string | StringKind |
// | []byte | BytesKind |
// | E | EnumKind |
// | M | MessageKind, GroupKind |
// +------------+-------------------------------------+
//
// The type E is the concrete enum type returned by NewEnum,
// which is often, but not required to be, a named int32 type.
// The type M is the concrete message type returned by NewMessage,
// which is often, but not required to be, a pointer to a named struct type.
func (t *Extension) GoType() reflect.Type {
t.lazyInit()
return t.goType
}
func (t *Extension) Descriptor() protoreflect.ExtensionDescriptor {
return t.ExtensionDescriptor
}
func (t *Extension) Format(s fmt.State, r rune) {
descfmt.FormatDesc(s, r, t)
}
func (t *Extension) lazyInit() {
t.once.Do(func() {
switch t.Kind() {
case protoreflect.EnumKind:
if t.NewEnum == nil || t.NewMessage != nil {
panic("NewEnum alone must be set")
}
e := t.NewEnum(0)
if e.Descriptor() != t.Enum() {
panic(fmt.Sprintf("mismatching enum descriptor: got %v, want %v", e.Descriptor(), t.Enum()))
}
t.goType = reflect.TypeOf(e)
case protoreflect.MessageKind, protoreflect.GroupKind:
if t.NewEnum != nil || t.NewMessage == nil {
panic("NewMessage alone must be set")
}
m := t.NewMessage()
if m.Descriptor() != t.Message() {
panic(fmt.Sprintf("mismatching message descriptor: got %v, want %v", m.Descriptor(), t.Message()))
}
t.goType = reflect.TypeOf(m.Interface())
default:
if t.NewEnum != nil || t.NewMessage != nil {
panic("neither NewEnum nor NewMessage may be set")
}
t.goType = goTypeForPBKind[t.Kind()]
}
switch t.Cardinality() {
case protoreflect.Optional:
switch t.Kind() {
case protoreflect.EnumKind:
t.new = func() protoreflect.Value {
return t.Default()
}
t.valueOf = func(v interface{}) protoreflect.Value {
ev := v.(protoreflect.Enum)
return protoreflect.ValueOf(ev.Number())
}
t.interfaceOf = func(v protoreflect.Value) interface{} {
return t.NewEnum(v.Enum())
}
case protoreflect.MessageKind, protoreflect.GroupKind:
t.new = func() protoreflect.Value {
return protoreflect.ValueOf(t.NewMessage())
}
t.valueOf = func(v interface{}) protoreflect.Value {
mv := v.(protoreflect.ProtoMessage).ProtoReflect()
return protoreflect.ValueOf(mv)
}
t.interfaceOf = func(v protoreflect.Value) interface{} {
return v.Message().Interface()
}
default:
t.new = func() protoreflect.Value {
v := t.Default()
if t.Kind() == protoreflect.BytesKind {
// Copy default bytes to avoid aliasing the original.
v = protoreflect.ValueOf(append([]byte(nil), v.Bytes()...))
}
return v
}
t.valueOf = func(v interface{}) protoreflect.Value {
return protoreflect.ValueOf(v)
}
t.interfaceOf = func(v protoreflect.Value) interface{} {
return v.Interface()
}
}
case protoreflect.Repeated:
var conv value.Converter
elemType := t.goType
switch t.Kind() {
case protoreflect.EnumKind:
conv = value.Converter{
PBValueOf: func(v reflect.Value) protoreflect.Value {
if v.Type() != elemType {
panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), elemType))
}
e := v.Interface().(protoreflect.Enum)
return protoreflect.ValueOf(e.Number())
},
GoValueOf: func(v protoreflect.Value) reflect.Value {
rv := reflect.ValueOf(t.NewEnum(v.Enum()))
if rv.Type() != elemType {
panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), elemType))
}
return rv
},
NewEnum: t.NewEnum,
}
case protoreflect.MessageKind, protoreflect.GroupKind:
conv = value.Converter{
PBValueOf: func(v reflect.Value) protoreflect.Value {
if v.Type() != elemType {
panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), elemType))
}
m := v.Interface().(protoreflect.ProtoMessage).ProtoReflect()
return protoreflect.ValueOf(m)
},
GoValueOf: func(v protoreflect.Value) reflect.Value {
rv := reflect.ValueOf(v.Message().Interface())
if rv.Type() != elemType {
panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), elemType))
}
return rv
},
NewMessage: t.NewMessage,
}
default:
conv = value.NewConverter(elemType, t.Kind())
}
t.goType = reflect.PtrTo(reflect.SliceOf(elemType))
t.new = func() protoreflect.Value {
v := reflect.New(t.goType.Elem()).Interface()
return protoreflect.ValueOf(value.ListOf(v, conv))
}
t.valueOf = func(v interface{}) protoreflect.Value {
return protoreflect.ValueOf(value.ListOf(v, conv))
}
t.interfaceOf = func(v protoreflect.Value) interface{} {
return v.List().(value.Unwrapper).ProtoUnwrap()
}
default:
panic(fmt.Sprintf("invalid cardinality: %v", t.Cardinality()))
}
})
}
var goTypeForPBKind = map[protoreflect.Kind]reflect.Type{
protoreflect.BoolKind: reflect.TypeOf(bool(false)),
protoreflect.Int32Kind: reflect.TypeOf(int32(0)),
protoreflect.Sint32Kind: reflect.TypeOf(int32(0)),
protoreflect.Sfixed32Kind: reflect.TypeOf(int32(0)),
protoreflect.Int64Kind: reflect.TypeOf(int64(0)),
protoreflect.Sint64Kind: reflect.TypeOf(int64(0)),
protoreflect.Sfixed64Kind: reflect.TypeOf(int64(0)),
protoreflect.Uint32Kind: reflect.TypeOf(uint32(0)),
protoreflect.Fixed32Kind: reflect.TypeOf(uint32(0)),
protoreflect.Uint64Kind: reflect.TypeOf(uint64(0)),
protoreflect.Fixed64Kind: reflect.TypeOf(uint64(0)),
protoreflect.FloatKind: reflect.TypeOf(float32(0)),
protoreflect.DoubleKind: reflect.TypeOf(float64(0)),
protoreflect.StringKind: reflect.TypeOf(string("")),
protoreflect.BytesKind: reflect.TypeOf([]byte(nil)),
}
var (
_ protoreflect.EnumType = (*Enum)(nil)
_ protoreflect.MessageType = (*Message)(nil)
_ protoreflect.ExtensionType = (*Extension)(nil)
)