internal/impl: support message and enum fields
Dynamically generate functions for handling message and enum fields,
regardless of whether they are of the v1 or v2 forms.
If a v1 message is encountered, it is automatically wrapped such that it
implements the v2 interface.
Change-Id: I457bc5286892e8fc00a61da7062dd33058daafd5
Reviewed-on: https://go-review.googlesource.com/c/143837
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/legacy_message.go b/internal/impl/legacy_message.go
index a617249..66a71c6 100644
--- a/internal/impl/legacy_message.go
+++ b/internal/impl/legacy_message.go
@@ -18,10 +18,26 @@
ptype "github.com/golang/protobuf/v2/reflect/prototype"
)
+var messageTypeCache sync.Map // map[reflect.Type]*MessageType
+
+// wrapLegacyMessage wraps v as a protoreflect.Message, where v must be
+// a *struct kind and not implement the v2 API already.
+func wrapLegacyMessage(v reflect.Value) pref.Message {
+ // Fast-path: check if a MessageType is cached for this concrete type.
+ if mt, ok := messageTypeCache.Load(v.Type()); ok {
+ return mt.(*MessageType).MessageOf(v.Interface())
+ }
+
+ // Slow-path: derive message descriptor and initialize MessageType.
+ mt := &MessageType{Desc: loadMessageDesc(v.Type())}
+ messageTypeCache.Store(v.Type(), mt)
+ return mt.MessageOf(v.Interface())
+}
+
var messageDescCache sync.Map // map[reflect.Type]protoreflect.MessageDescriptor
// loadMessageDesc returns an MessageDescriptor derived from the Go type,
-// which must be an *struct kind and not implement the v2 API already.
+// which must be a *struct kind and not implement the v2 API already.
func loadMessageDesc(t reflect.Type) pref.MessageDescriptor {
return messageDescSet{}.Load(t)
}
diff --git a/internal/impl/message_field.go b/internal/impl/message_field.go
index bdab90a..140f5cd 100644
--- a/internal/impl/message_field.go
+++ b/internal/impl/message_field.go
@@ -60,8 +60,7 @@
rv := p.apply(fieldOffset).asType(fs.Type).Elem()
if rv.IsNil() || rv.Elem().Type().Elem() != ot {
if fd.Kind() == pref.MessageKind || fd.Kind() == pref.GroupKind {
- // Return a typed nil pointer of the message type to be
- // consistent with the behavior of generated getters.
+ // TODO: Should this return an invalid protoreflect.Value?
rv = reflect.Zero(ot.Field(0).Type)
return conv.toPB(rv)
}
@@ -198,6 +197,8 @@
}
func (ms mapReflect) ProtoMutable() {}
+var _ pref.Map = mapReflect{}
+
func fieldInfoForVector(fd pref.FieldDescriptor, fs reflect.StructField) fieldInfo {
ft := fs.Type
if ft.Kind() != reflect.Slice {
@@ -349,8 +350,44 @@
}
func fieldInfoForMessage(fd pref.FieldDescriptor, fs reflect.StructField) fieldInfo {
- // TODO: support vector fields.
- panic(fmt.Sprintf("invalid field: %v", fd))
+ ft := fs.Type
+ conv := matchGoTypePBKind(ft, fd.Kind())
+ fieldOffset := offsetOf(fs)
+ // TODO: Implement unsafe fast path?
+ return fieldInfo{
+ has: func(p pointer) bool {
+ rv := p.apply(fieldOffset).asType(fs.Type).Elem()
+ return !rv.IsNil()
+ },
+ get: func(p pointer) pref.Value {
+ // TODO: If rv.IsNil(), should this return a typed-nil pointer or
+ // an invalid protoreflect.Value?
+ //
+ // Returning a typed nil pointer assumes that such values
+ // are valid for all possible custom message types,
+ // which may not be case for dynamic messages.
+ rv := p.apply(fieldOffset).asType(fs.Type).Elem()
+ return conv.toPB(rv)
+ },
+ set: func(p pointer, v pref.Value) {
+ // TODO: Similarly, is it valid to set this to a typed nil pointer?
+ rv := p.apply(fieldOffset).asType(fs.Type).Elem()
+ rv.Set(conv.toGo(v))
+ },
+ clear: func(p pointer) {
+ rv := p.apply(fieldOffset).asType(fs.Type).Elem()
+ rv.Set(reflect.Zero(rv.Type()))
+ },
+ mutable: func(p pointer) pref.Mutable {
+ // Mutable is only valid for messages and panics for other kinds.
+ rv := p.apply(fieldOffset).asType(fs.Type).Elem()
+ if rv.IsNil() {
+ pv := pref.ValueOf(conv.newMessage())
+ rv.Set(conv.toGo(pv))
+ }
+ return conv.toPB(rv).Message()
+ },
+ }
}
// messageV1 is the protoV1.Message interface.
@@ -424,22 +461,84 @@
case pref.EnumKind:
// Handle v2 enums, which must satisfy the proto.Enum interface.
if t.Kind() != reflect.Ptr && t.Implements(enumIfaceV2) {
- // TODO: implement this.
+ et := reflect.Zero(t).Interface().(pref.ProtoEnum).ProtoReflect().Type()
+ return converter{
+ toPB: func(v reflect.Value) pref.Value {
+ if v.Type() != t {
+ panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), t))
+ }
+ e := v.Interface().(pref.ProtoEnum)
+ return pref.ValueOf(e.ProtoReflect().Number())
+ },
+ toGo: func(v pref.Value) reflect.Value {
+ rv := reflect.ValueOf(et.GoNew(v.Enum()))
+ if rv.Type() != t {
+ panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), t))
+ }
+ return rv
+ },
+ }
}
// Handle v1 enums, which we identify as simply a named int32 type.
if t.Kind() == reflect.Int32 && t.PkgPath() != "" {
- // TODO: need logic to wrap a legacy enum to implement this.
+ return converter{
+ toPB: func(v reflect.Value) pref.Value {
+ if v.Type() != t {
+ panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), t))
+ }
+ return pref.ValueOf(pref.EnumNumber(v.Int()))
+ },
+ toGo: func(v pref.Value) reflect.Value {
+ return reflect.ValueOf(v.Enum()).Convert(t)
+ },
+ }
}
case pref.MessageKind, pref.GroupKind:
// Handle v2 messages, which must satisfy the proto.Message interface.
if t.Kind() == reflect.Ptr && t.Implements(messageIfaceV2) {
- // TODO: implement this.
+ mt := reflect.Zero(t).Interface().(pref.ProtoMessage).ProtoReflect().Type()
+ return converter{
+ toPB: func(v reflect.Value) pref.Value {
+ if v.Type() != t {
+ panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), t))
+ }
+ return pref.ValueOf(v.Interface())
+ },
+ toGo: func(v pref.Value) reflect.Value {
+ rv := reflect.ValueOf(v.Message())
+ if rv.Type() != t {
+ panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), t))
+ }
+ return rv
+ },
+ newMessage: func() pref.Message {
+ return mt.GoNew().ProtoReflect()
+ },
+ }
}
// Handle v1 messages, which we need to wrap as a v2 message.
if t.Kind() == reflect.Ptr && t.Implements(messageIfaceV1) {
- // TODO: need logic to wrap a legacy message.
+ return converter{
+ toPB: func(v reflect.Value) pref.Value {
+ if v.Type() != t {
+ panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), t))
+ }
+ return pref.ValueOf(wrapLegacyMessage(v))
+ },
+ toGo: func(v pref.Value) reflect.Value {
+ type unwrapper interface{ Unwrap() interface{} }
+ rv := reflect.ValueOf(v.Message().(unwrapper).Unwrap())
+ if rv.Type() != t {
+ panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), t))
+ }
+ return rv
+ },
+ newMessage: func() pref.Message {
+ return wrapLegacyMessage(reflect.New(t.Elem()))
+ },
+ }
}
}
panic(fmt.Sprintf("invalid Go type %v for protobuf kind %v", t, k))
diff --git a/internal/impl/message_test.go b/internal/impl/message_test.go
index ce8bc70..4041f02 100644
--- a/internal/impl/message_test.go
+++ b/internal/impl/message_test.go
@@ -41,13 +41,34 @@
MyString string
MyBytes []byte
- NamedStrings []MyString
- NamedBytes []MyBytes
+ VectorStrings []MyString
+ VectorBytes []MyBytes
MapStrings map[MyString]MyString
MapBytes map[MyString]MyBytes
+
+ MyEnumV1 pref.EnumNumber
+ MyEnumV2 string
+ myEnumV2 MyEnumV2
+
+ MyMessageV1 struct {
+ // SubMessage *Message
+ }
+ MyMessageV2 map[pref.FieldNumber]pref.Value
+ myMessageV2 MyMessageV2
)
+func (e MyEnumV2) ProtoReflect() pref.Enum { return myEnumV2(e) }
+func (e myEnumV2) Type() pref.EnumType { return nil } // TODO
+func (e myEnumV2) Number() pref.EnumNumber { return 0 } // TODO
+
+func (m *MyMessageV2) ProtoReflect() pref.Message { return (*myMessageV2)(m) }
+func (m *myMessageV2) Type() pref.MessageType { return nil } // TODO
+func (m *myMessageV2) KnownFields() pref.KnownFields { return nil } // TODO
+func (m *myMessageV2) UnknownFields() pref.UnknownFields { return nil } // TODO
+func (m *myMessageV2) Interface() pref.ProtoMessage { return (*MyMessageV2)(m) }
+func (m *myMessageV2) ProtoMutable() {}
+
// List of test operations to perform on messages, vectors, or maps.
type (
messageOp interface{} // equalMessage | hasFields | getFields | setFields | clearFields | vectorFields | mapFields
@@ -96,33 +117,33 @@
// TODO: Mutable
)
+type ScalarProto2 struct {
+ Bool *bool `protobuf:"1"`
+ Int32 *int32 `protobuf:"2"`
+ Int64 *int64 `protobuf:"3"`
+ Uint32 *uint32 `protobuf:"4"`
+ Uint64 *uint64 `protobuf:"5"`
+ Float32 *float32 `protobuf:"6"`
+ Float64 *float64 `protobuf:"7"`
+ String *string `protobuf:"8"`
+ StringA []byte `protobuf:"9"`
+ Bytes []byte `protobuf:"10"`
+ BytesA *string `protobuf:"11"`
+
+ MyBool *MyBool `protobuf:"12"`
+ MyInt32 *MyInt32 `protobuf:"13"`
+ MyInt64 *MyInt64 `protobuf:"14"`
+ MyUint32 *MyUint32 `protobuf:"15"`
+ MyUint64 *MyUint64 `protobuf:"16"`
+ MyFloat32 *MyFloat32 `protobuf:"17"`
+ MyFloat64 *MyFloat64 `protobuf:"18"`
+ MyString *MyString `protobuf:"19"`
+ MyStringA MyBytes `protobuf:"20"`
+ MyBytes MyBytes `protobuf:"21"`
+ MyBytesA *MyString `protobuf:"22"`
+}
+
func TestScalarProto2(t *testing.T) {
- type ScalarProto2 struct {
- Bool *bool `protobuf:"1"`
- Int32 *int32 `protobuf:"2"`
- Int64 *int64 `protobuf:"3"`
- Uint32 *uint32 `protobuf:"4"`
- Uint64 *uint64 `protobuf:"5"`
- Float32 *float32 `protobuf:"6"`
- Float64 *float64 `protobuf:"7"`
- String *string `protobuf:"8"`
- StringA []byte `protobuf:"9"`
- Bytes []byte `protobuf:"10"`
- BytesA *string `protobuf:"11"`
-
- MyBool *MyBool `protobuf:"12"`
- MyInt32 *MyInt32 `protobuf:"13"`
- MyInt64 *MyInt64 `protobuf:"14"`
- MyUint32 *MyUint32 `protobuf:"15"`
- MyUint64 *MyUint64 `protobuf:"16"`
- MyFloat32 *MyFloat32 `protobuf:"17"`
- MyFloat64 *MyFloat64 `protobuf:"18"`
- MyString *MyString `protobuf:"19"`
- MyStringA MyBytes `protobuf:"20"`
- MyBytes MyBytes `protobuf:"21"`
- MyBytesA *MyString `protobuf:"22"`
- }
-
mi := MessageType{Desc: mustMakeMessageDesc(ptype.StandaloneMessage{
Syntax: pref.Proto2,
FullName: "ScalarProto2",
@@ -182,33 +203,33 @@
})
}
+type ScalarProto3 struct {
+ Bool bool `protobuf:"1"`
+ Int32 int32 `protobuf:"2"`
+ Int64 int64 `protobuf:"3"`
+ Uint32 uint32 `protobuf:"4"`
+ Uint64 uint64 `protobuf:"5"`
+ Float32 float32 `protobuf:"6"`
+ Float64 float64 `protobuf:"7"`
+ String string `protobuf:"8"`
+ StringA []byte `protobuf:"9"`
+ Bytes []byte `protobuf:"10"`
+ BytesA string `protobuf:"11"`
+
+ MyBool MyBool `protobuf:"12"`
+ MyInt32 MyInt32 `protobuf:"13"`
+ MyInt64 MyInt64 `protobuf:"14"`
+ MyUint32 MyUint32 `protobuf:"15"`
+ MyUint64 MyUint64 `protobuf:"16"`
+ MyFloat32 MyFloat32 `protobuf:"17"`
+ MyFloat64 MyFloat64 `protobuf:"18"`
+ MyString MyString `protobuf:"19"`
+ MyStringA MyBytes `protobuf:"20"`
+ MyBytes MyBytes `protobuf:"21"`
+ MyBytesA MyString `protobuf:"22"`
+}
+
func TestScalarProto3(t *testing.T) {
- type ScalarProto3 struct {
- Bool bool `protobuf:"1"`
- Int32 int32 `protobuf:"2"`
- Int64 int64 `protobuf:"3"`
- Uint32 uint32 `protobuf:"4"`
- Uint64 uint64 `protobuf:"5"`
- Float32 float32 `protobuf:"6"`
- Float64 float64 `protobuf:"7"`
- String string `protobuf:"8"`
- StringA []byte `protobuf:"9"`
- Bytes []byte `protobuf:"10"`
- BytesA string `protobuf:"11"`
-
- MyBool MyBool `protobuf:"12"`
- MyInt32 MyInt32 `protobuf:"13"`
- MyInt64 MyInt64 `protobuf:"14"`
- MyUint32 MyUint32 `protobuf:"15"`
- MyUint64 MyUint64 `protobuf:"16"`
- MyFloat32 MyFloat32 `protobuf:"17"`
- MyFloat64 MyFloat64 `protobuf:"18"`
- MyString MyString `protobuf:"19"`
- MyStringA MyBytes `protobuf:"20"`
- MyBytes MyBytes `protobuf:"21"`
- MyBytesA MyString `protobuf:"22"`
- }
-
mi := MessageType{Desc: mustMakeMessageDesc(ptype.StandaloneMessage{
Syntax: pref.Proto3,
FullName: "ScalarProto3",
@@ -277,31 +298,31 @@
})
}
+type RepeatedScalars struct {
+ Bools []bool `protobuf:"1"`
+ Int32s []int32 `protobuf:"2"`
+ Int64s []int64 `protobuf:"3"`
+ Uint32s []uint32 `protobuf:"4"`
+ Uint64s []uint64 `protobuf:"5"`
+ Float32s []float32 `protobuf:"6"`
+ Float64s []float64 `protobuf:"7"`
+ Strings []string `protobuf:"8"`
+ StringsA [][]byte `protobuf:"9"`
+ Bytes [][]byte `protobuf:"10"`
+ BytesA []string `protobuf:"11"`
+
+ MyStrings1 []MyString `protobuf:"12"`
+ MyStrings2 []MyBytes `protobuf:"13"`
+ MyBytes1 []MyBytes `protobuf:"14"`
+ MyBytes2 []MyString `protobuf:"15"`
+
+ MyStrings3 VectorStrings `protobuf:"16"`
+ MyStrings4 VectorBytes `protobuf:"17"`
+ MyBytes3 VectorBytes `protobuf:"18"`
+ MyBytes4 VectorStrings `protobuf:"19"`
+}
+
func TestRepeatedScalars(t *testing.T) {
- type RepeatedScalars struct {
- Bools []bool `protobuf:"1"`
- Int32s []int32 `protobuf:"2"`
- Int64s []int64 `protobuf:"3"`
- Uint32s []uint32 `protobuf:"4"`
- Uint64s []uint64 `protobuf:"5"`
- Float32s []float32 `protobuf:"6"`
- Float64s []float64 `protobuf:"7"`
- Strings []string `protobuf:"8"`
- StringsA [][]byte `protobuf:"9"`
- Bytes [][]byte `protobuf:"10"`
- BytesA []string `protobuf:"11"`
-
- MyStrings1 []MyString `protobuf:"12"`
- MyStrings2 []MyBytes `protobuf:"13"`
- MyBytes1 []MyBytes `protobuf:"14"`
- MyBytes2 []MyString `protobuf:"15"`
-
- MyStrings3 NamedStrings `protobuf:"16"`
- MyStrings4 NamedBytes `protobuf:"17"`
- MyBytes3 NamedBytes `protobuf:"18"`
- MyBytes4 NamedStrings `protobuf:"19"`
- }
-
mi := MessageType{Desc: mustMakeMessageDesc(ptype.StandaloneMessage{
Syntax: pref.Proto2,
FullName: "RepeatedScalars",
@@ -351,10 +372,10 @@
MyBytes1: []MyBytes{[]byte("14"), nil, []byte("fourteen")},
MyBytes2: []MyString{"15", "", "fifteen"},
- MyStrings3: NamedStrings{"16", "", "sixteen"},
- MyStrings4: NamedBytes{[]byte("17"), nil, []byte("seventeen")},
- MyBytes3: NamedBytes{[]byte("18"), nil, []byte("eighteen")},
- MyBytes4: NamedStrings{"19", "", "nineteen"},
+ MyStrings3: VectorStrings{"16", "", "sixteen"},
+ MyStrings4: VectorBytes{[]byte("17"), nil, []byte("seventeen")},
+ MyBytes3: VectorBytes{[]byte("18"), nil, []byte("eighteen")},
+ MyBytes4: VectorStrings{"19", "", "nineteen"},
})
wantFS := want.KnownFields()
@@ -418,38 +439,38 @@
})
}
+type MapScalars struct {
+ KeyBools map[bool]string `protobuf:"1"`
+ KeyInt32s map[int32]string `protobuf:"2"`
+ KeyInt64s map[int64]string `protobuf:"3"`
+ KeyUint32s map[uint32]string `protobuf:"4"`
+ KeyUint64s map[uint64]string `protobuf:"5"`
+ KeyStrings map[string]string `protobuf:"6"`
+
+ ValBools map[string]bool `protobuf:"7"`
+ ValInt32s map[string]int32 `protobuf:"8"`
+ ValInt64s map[string]int64 `protobuf:"9"`
+ ValUint32s map[string]uint32 `protobuf:"10"`
+ ValUint64s map[string]uint64 `protobuf:"11"`
+ ValFloat32s map[string]float32 `protobuf:"12"`
+ ValFloat64s map[string]float64 `protobuf:"13"`
+ ValStrings map[string]string `protobuf:"14"`
+ ValStringsA map[string][]byte `protobuf:"15"`
+ ValBytes map[string][]byte `protobuf:"16"`
+ ValBytesA map[string]string `protobuf:"17"`
+
+ MyStrings1 map[MyString]MyString `protobuf:"18"`
+ MyStrings2 map[MyString]MyBytes `protobuf:"19"`
+ MyBytes1 map[MyString]MyBytes `protobuf:"20"`
+ MyBytes2 map[MyString]MyString `protobuf:"21"`
+
+ MyStrings3 MapStrings `protobuf:"22"`
+ MyStrings4 MapBytes `protobuf:"23"`
+ MyBytes3 MapBytes `protobuf:"24"`
+ MyBytes4 MapStrings `protobuf:"25"`
+}
+
func TestMapScalars(t *testing.T) {
- type MapScalars struct {
- KeyBools map[bool]string `protobuf:"1"`
- KeyInt32s map[int32]string `protobuf:"2"`
- KeyInt64s map[int64]string `protobuf:"3"`
- KeyUint32s map[uint32]string `protobuf:"4"`
- KeyUint64s map[uint64]string `protobuf:"5"`
- KeyStrings map[string]string `protobuf:"6"`
-
- ValBools map[string]bool `protobuf:"7"`
- ValInt32s map[string]int32 `protobuf:"8"`
- ValInt64s map[string]int64 `protobuf:"9"`
- ValUint32s map[string]uint32 `protobuf:"10"`
- ValUint64s map[string]uint64 `protobuf:"11"`
- ValFloat32s map[string]float32 `protobuf:"12"`
- ValFloat64s map[string]float64 `protobuf:"13"`
- ValStrings map[string]string `protobuf:"14"`
- ValStringsA map[string][]byte `protobuf:"15"`
- ValBytes map[string][]byte `protobuf:"16"`
- ValBytesA map[string]string `protobuf:"17"`
-
- MyStrings1 map[MyString]MyString `protobuf:"18"`
- MyStrings2 map[MyString]MyBytes `protobuf:"19"`
- MyBytes1 map[MyString]MyBytes `protobuf:"20"`
- MyBytes2 map[MyString]MyString `protobuf:"21"`
-
- MyStrings3 MapStrings `protobuf:"22"`
- MyStrings4 MapBytes `protobuf:"23"`
- MyBytes3 MapBytes `protobuf:"24"`
- MyBytes4 MapStrings `protobuf:"25"`
- }
-
mustMakeMapEntry := func(n pref.FieldNumber, keyKind, valKind pref.Kind) ptype.Field {
return ptype.Field{
Name: pref.Name(fmt.Sprintf("f%d", n)),