types/dynamicpb: support dynamic extensions
Add a dynamicpb.NewExtensionType function to permit creating extension
types from descriptors.
Also fix a some bugs around extension field handling:
When creating a new value for an extension field, use the
ExtensionType's Zero or New method to create the value.
Ensure that prototest exercises true zero-values of fields. (i.e.,
getting a list, map, or message from an empty message rather than
creating a new empty one with NewField.)
Change-Id: Idb8e87cdc92692610e12a4b8a68c34b129fae617
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/186180
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/testing/prototest/prototest.go b/testing/prototest/prototest.go
index fbfe7c6..20d20c9 100644
--- a/testing/prototest/prototest.go
+++ b/testing/prototest/prototest.go
@@ -29,6 +29,12 @@
//
// If nil, TestMessage will look for extension types in the global registry.
ExtensionTypes []pref.ExtensionType
+
+ // Resolver is used for looking up types when unmarshaling extension fields.
+ // If nil, this defaults to using protoregistry.GlobalTypes.
+ Resolver interface {
+ preg.ExtensionTypeResolver
+ }
}
// TestMessage runs the provided m through a series of tests
@@ -57,12 +63,20 @@
// Test round-trip marshal/unmarshal.
m2 := m.ProtoReflect().New().Interface()
populateMessage(m2.ProtoReflect(), 1, nil)
- b, err := (proto.MarshalOptions{AllowPartial: true}).Marshal(m2)
+ for _, xt := range opts.ExtensionTypes {
+ m2.ProtoReflect().Set(xt.TypeDescriptor(), newValue(m2.ProtoReflect(), xt.TypeDescriptor(), 1, nil))
+ }
+ b, err := proto.MarshalOptions{
+ AllowPartial: true,
+ }.Marshal(m2)
if err != nil {
t.Errorf("Marshal() = %v, want nil\n%v", err, marshalText(m2))
}
m3 := m.ProtoReflect().New().Interface()
- if err := (proto.UnmarshalOptions{AllowPartial: true}).Unmarshal(b, m3); err != nil {
+ if err := (proto.UnmarshalOptions{
+ AllowPartial: true,
+ Resolver: opts.Resolver,
+ }.Unmarshal(b, m3)); err != nil {
t.Errorf("Unmarshal() = %v, want nil\n%v", err, marshalText(m2))
}
if !proto.Equal(m2, m3) {
@@ -150,7 +164,7 @@
}
case fd.IsMap():
if got := m.Get(fd); got.Map().Len() != 0 {
- t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty list", name, num, formatValue(got))
+ t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty map", name, num, formatValue(got))
}
case fd.Message() == nil:
if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
@@ -158,6 +172,21 @@
}
}
+ // Set to the default value.
+ switch {
+ case fd.IsList() || fd.IsMap():
+ m.Set(fd, m.Get(fd))
+ if got, want := m.Has(fd), fd.IsExtension() || fd.ContainingOneof() != nil; got != want {
+ t.Errorf("after setting %q to default:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
+ }
+ case fd.Message() == nil:
+ m.Set(fd, m.Get(fd))
+ if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
+ t.Errorf("after setting %q to default:\nMessage.Get(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
+ }
+ }
+ m.Clear(fd)
+
// Set to the wrong type.
v := pref.ValueOf("")
if fd.Kind() == pref.StringKind {
@@ -508,26 +537,29 @@
func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.MessageDescriptor) pref.Value {
switch {
case fd.IsList():
- list := m.NewField(fd).List()
if n == 0 {
- return pref.ValueOf(list)
+ return m.New().Get(fd)
}
+ list := m.NewField(fd).List()
list.Append(newListElement(fd, list, 0, stack))
list.Append(newListElement(fd, list, minVal, stack))
list.Append(newListElement(fd, list, maxVal, stack))
list.Append(newListElement(fd, list, n, stack))
return pref.ValueOf(list)
case fd.IsMap():
- mapv := m.NewField(fd).Map()
if n == 0 {
- return pref.ValueOf(mapv)
+ return m.New().Get(fd)
}
+ mapv := m.NewField(fd).Map()
mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, stack))
mapv.Set(newMapKey(fd, minVal), newMapValue(fd, mapv, minVal, stack))
mapv.Set(newMapKey(fd, maxVal), newMapValue(fd, mapv, maxVal, stack))
mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, newSeed(n, 0), stack))
return pref.ValueOf(mapv)
case fd.Message() != nil:
+ //if n == 0 {
+ // return m.New().Get(fd)
+ //}
return populateMessage(m.NewField(fd).Message(), n, stack)
default:
return newScalarValue(fd, n)
diff --git a/types/dynamicpb/dynamic.go b/types/dynamicpb/dynamic.go
index 7b8c8d0..06616a1 100644
--- a/types/dynamicpb/dynamic.go
+++ b/types/dynamicpb/dynamic.go
@@ -122,16 +122,22 @@
func (m *Message) Get(fd pref.FieldDescriptor) pref.Value {
m.checkField(fd)
num := fd.Number()
- if v, ok := m.known[num]; ok {
- if !fd.IsExtension() || fd == m.ext[num] {
- return v
+ if fd.IsExtension() {
+ if fd != m.ext[num] {
+ return fd.(pref.ExtensionTypeDescriptor).Type().Zero()
}
+ return m.known[num]
+ }
+ if v, ok := m.known[num]; ok {
+ return v
}
switch {
case fd.IsMap():
return pref.ValueOf(&dynamicMap{desc: fd})
- case fd.Cardinality() == pref.Repeated:
+ case fd.IsList():
return pref.ValueOf(emptyList{desc: fd})
+ case fd.Message() != nil:
+ return pref.ValueOf(&Message{desc: fd.Message()})
case fd.Kind() == pref.BytesKind:
return pref.ValueOf(append([]byte(nil), fd.Default().Bytes()...))
default:
@@ -143,15 +149,23 @@
// See protoreflect.Message for details.
func (m *Message) Mutable(fd pref.FieldDescriptor) pref.Value {
m.checkField(fd)
- num := fd.Number()
- if v, ok := m.known[num]; ok {
- if !fd.IsExtension() || fd == m.ext[num] {
- return v
- }
- }
if !fd.IsMap() && !fd.IsList() && fd.Message() == nil {
panic(errors.New("%v: getting mutable reference to non-composite type", fd.FullName()))
}
+ if m.known == nil {
+ panic(errors.New("%v: modification of read-only message", fd.FullName()))
+ }
+ num := fd.Number()
+ if fd.IsExtension() {
+ if fd != m.ext[num] {
+ m.ext[num] = fd
+ m.known[num] = fd.(pref.ExtensionTypeDescriptor).Type().New()
+ }
+ return m.known[num]
+ }
+ if v, ok := m.known[num]; ok {
+ return v
+ }
m.clearOtherOneofFields(fd)
m.known[num] = m.NewField(fd)
if fd.IsExtension() {
@@ -164,22 +178,16 @@
// See protoreflect.Message for details.
func (m *Message) Set(fd pref.FieldDescriptor, v pref.Value) {
m.checkField(fd)
- switch {
- case fd.IsExtension():
+ if m.known == nil {
+ panic(errors.New("%v: modification of read-only message", fd.FullName()))
+ }
+ if fd.IsExtension() {
if !fd.(pref.ExtensionTypeDescriptor).Type().IsValidValue(v) {
panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
}
m.ext[fd.Number()] = fd
- case fd.IsMap():
- if mapv, ok := v.Interface().(*dynamicMap); !ok || mapv.desc != fd {
- panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
- }
- case fd.IsList():
- if list, ok := v.Interface().(*dynamicList); !ok || list.desc != fd {
- panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
- }
- default:
- typecheckSingular(fd, v)
+ } else {
+ typecheck(fd, v)
}
m.clearOtherOneofFields(fd)
m.known[fd.Number()] = v
@@ -251,6 +259,9 @@
// SetUnknown sets the raw unknown fields.
// See protoreflect.Message for details.
func (m *Message) SetUnknown(r pref.RawFields) {
+ if m.known == nil {
+ panic(errors.New("%v: modification of read-only message", m.desc.FullName()))
+ }
m.unknown = r
}
@@ -406,7 +417,43 @@
return true
}
+func typecheck(fd pref.FieldDescriptor, v pref.Value) {
+ if err := typeIsValid(fd, v); err != nil {
+ panic(err)
+ }
+}
+
+func typeIsValid(fd pref.FieldDescriptor, v pref.Value) error {
+ switch {
+ case fd.IsMap():
+ if mapv, ok := v.Interface().(*dynamicMap); !ok || mapv.desc != fd {
+ return errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface())
+ }
+ return nil
+ case fd.IsList():
+ switch list := v.Interface().(type) {
+ case *dynamicList:
+ if list.desc == fd {
+ return nil
+ }
+ case emptyList:
+ if list.desc == fd {
+ return nil
+ }
+ }
+ return errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface())
+ default:
+ return singularTypeIsValid(fd, v)
+ }
+}
+
func typecheckSingular(fd pref.FieldDescriptor, v pref.Value) {
+ if err := singularTypeIsValid(fd, v); err != nil {
+ panic(err)
+ }
+}
+
+func singularTypeIsValid(fd pref.FieldDescriptor, v pref.Value) error {
vi := v.Interface()
var ok bool
switch fd.Kind() {
@@ -435,12 +482,16 @@
var m pref.Message
m, ok = vi.(pref.Message)
if ok && m.Descriptor().FullName() != fd.Message().FullName() {
- panic(errors.New("%v: assigning invalid message type %v", fd.FullName(), m.Descriptor().FullName()))
+ return errors.New("%v: assigning invalid message type %v", fd.FullName(), m.Descriptor().FullName())
+ }
+ if dm, ok := vi.(*Message); ok && dm.known == nil {
+ return errors.New("%v: assigning invalid zero-value message", fd.FullName())
}
}
if !ok {
- panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
+ return errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface())
}
+ return nil
}
func newListEntry(fd pref.FieldDescriptor) pref.Value {
@@ -470,3 +521,102 @@
}
panic(errors.New("%v: unknown kind %v", fd.FullName(), fd.Kind()))
}
+
+// extensionType is a dynamic protoreflect.ExtensionType.
+type extensionType struct {
+ desc extensionTypeDescriptor
+}
+
+// NewExtensionType creates a new ExtensionType with the provided descriptor.
+//
+// Dynamic ExtensionTypes with the same descriptor compare as equal. That is,
+// if xd1 == xd2, then NewExtensionType(xd1) == NewExtensionType(xd2).
+//
+// The InterfaceOf and ValueOf methods of the extension type are defined as:
+//
+// func (xt extensionType) ValueOf(iv interface{}) protoreflect.Value {
+// return protoreflect.ValueOf(iv)
+// }
+//
+// func (xt extensionType) InterfaceOf(v protoreflect.Value) interface{} {
+// return v.Interface()
+// }
+//
+// The Go type used by the proto.GetExtension and proto.SetExtension functions
+// is determined by these methods, and is therefore equivalent to the Go type
+// used to represent a protoreflect.Value. See the protoreflect.Value
+// documentation for more details.
+func NewExtensionType(desc pref.ExtensionDescriptor) pref.ExtensionType {
+ if xt, ok := desc.(pref.ExtensionTypeDescriptor); ok {
+ desc = xt.Descriptor()
+ }
+ return extensionType{extensionTypeDescriptor{desc}}
+}
+
+func (xt extensionType) New() pref.Value {
+ switch {
+ case xt.desc.IsMap():
+ return pref.ValueOf(&dynamicMap{
+ desc: xt.desc,
+ mapv: make(map[interface{}]pref.Value),
+ })
+ case xt.desc.IsList():
+ return pref.ValueOf(&dynamicList{desc: xt.desc})
+ case xt.desc.Message() != nil:
+ return pref.ValueOf(New(xt.desc.Message()))
+ default:
+ return xt.desc.Default()
+ }
+}
+
+func (xt extensionType) Zero() pref.Value {
+ switch {
+ case xt.desc.IsMap():
+ return pref.ValueOf(&dynamicMap{desc: xt.desc})
+ case xt.desc.Cardinality() == pref.Repeated:
+ return pref.ValueOf(emptyList{desc: xt.desc})
+ case xt.desc.Message() != nil:
+ return pref.ValueOf(&Message{desc: xt.desc.Message()})
+ default:
+ return xt.desc.Default()
+ }
+}
+
+func (xt extensionType) GoType() reflect.Type {
+ return reflect.TypeOf(xt.InterfaceOf(xt.New()))
+}
+
+func (xt extensionType) TypeDescriptor() pref.ExtensionTypeDescriptor {
+ return xt.desc
+}
+
+func (xt extensionType) ValueOf(iv interface{}) pref.Value {
+ v := pref.ValueOf(iv)
+ typecheck(xt.desc, v)
+ return v
+}
+
+func (xt extensionType) InterfaceOf(v pref.Value) interface{} {
+ typecheck(xt.desc, v)
+ return v.Interface()
+}
+
+func (xt extensionType) IsValidInterface(iv interface{}) bool {
+ return typeIsValid(xt.desc, pref.ValueOf(iv)) == nil
+}
+
+func (xt extensionType) IsValidValue(v pref.Value) bool {
+ return typeIsValid(xt.desc, v) == nil
+}
+
+type extensionTypeDescriptor struct {
+ pref.ExtensionDescriptor
+}
+
+func (xt extensionTypeDescriptor) Type() pref.ExtensionType {
+ return extensionType{xt}
+}
+
+func (xt extensionTypeDescriptor) Descriptor() pref.ExtensionDescriptor {
+ return xt.ExtensionDescriptor
+}
diff --git a/types/dynamicpb/dynamic_test.go b/types/dynamicpb/dynamic_test.go
index 5a7db95..a4696d4 100644
--- a/types/dynamicpb/dynamic_test.go
+++ b/types/dynamicpb/dynamic_test.go
@@ -8,6 +8,8 @@
"testing"
"google.golang.org/protobuf/proto"
+ pref "google.golang.org/protobuf/reflect/protoreflect"
+ preg "google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/testing/prototest"
"google.golang.org/protobuf/types/dynamicpb"
@@ -24,3 +26,37 @@
prototest.TestMessage(t, dynamicpb.New(message.ProtoReflect().Descriptor()), prototest.MessageOptions{})
}
}
+
+func TestDynamicExtensions(t *testing.T) {
+ file, err := preg.GlobalFiles.FindFileByPath("test/ext.proto")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ md := (&testpb.TestAllExtensions{}).ProtoReflect().Descriptor()
+ opts := prototest.MessageOptions{
+ Resolver: extResolver{},
+ }
+ for i := 0; i < file.Extensions().Len(); i++ {
+ opts.ExtensionTypes = append(opts.ExtensionTypes, dynamicpb.NewExtensionType(file.Extensions().Get(i)))
+ }
+ prototest.TestMessage(t, dynamicpb.New(md), opts)
+}
+
+type extResolver struct{}
+
+func (extResolver) FindExtensionByName(field pref.FullName) (pref.ExtensionType, error) {
+ xt, err := preg.GlobalTypes.FindExtensionByName(field)
+ if err != nil {
+ return nil, err
+ }
+ return dynamicpb.NewExtensionType(xt.TypeDescriptor().Descriptor()), nil
+}
+
+func (extResolver) FindExtensionByNumber(message pref.FullName, field pref.FieldNumber) (pref.ExtensionType, error) {
+ xt, err := preg.GlobalTypes.FindExtensionByNumber(message, field)
+ if err != nil {
+ return nil, err
+ }
+ return dynamicpb.NewExtensionType(xt.TypeDescriptor().Descriptor()), nil
+}