testing/prototest: refactor prototest API
For consistency with other options types in the protobuf module, make
the test function a method of the options.
Drop the ExtensionTypes option and just look up the extension types to
test with in the provided resolver.
Change-Id: I7918bd10b7c003e4af56d27521d30218653d5b4d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/219142
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/testing/prototest/prototest.go b/testing/prototest/prototest.go
index 0c6b700..4fccb70 100644
--- a/testing/prototest/prototest.go
+++ b/testing/prototest/prototest.go
@@ -17,47 +17,42 @@
"google.golang.org/protobuf/internal/encoding/wire"
"google.golang.org/protobuf/proto"
pref "google.golang.org/protobuf/reflect/protoreflect"
- preg "google.golang.org/protobuf/reflect/protoregistry"
+ "google.golang.org/protobuf/reflect/protoregistry"
)
// TODO: Test invalid field descriptors or oneof descriptors.
// TODO: This should test the functionality that can be provided by fast-paths.
-// MessageOptions configure message tests.
-type MessageOptions struct {
- // ExtensionTypes is a list of types to test with.
- //
- // If nil, TestMessage will look for extension types in the global registry.
- ExtensionTypes []pref.ExtensionType
-
- // Resolver is used for looking up types when unmarshaling extension fields.
+// Message tests a message implemention.
+type Message struct {
+ // Resolver is used to determine the list of extension fields to test with.
// If nil, this defaults to using protoregistry.GlobalTypes.
Resolver interface {
- preg.ExtensionTypeResolver
+ FindExtensionByName(field pref.FullName) (pref.ExtensionType, error)
+ FindExtensionByNumber(message pref.FullName, field pref.FieldNumber) (pref.ExtensionType, error)
+ RangeExtensionsByMessage(message pref.FullName, f func(pref.ExtensionType) bool)
}
}
-// TODO(blocks): TestMessage should not take in MessageOptions,
-// but have a MessageOptions.Test method instead.
+// Test performs tests on a MessageType implementation.
+func (test Message) Test(t testing.TB, mt pref.MessageType) {
+ testType(t, mt)
-// TestMessage runs the provided m through a series of tests
-// exercising the protobuf reflection API.
-func TestMessage(t testing.TB, m proto.Message, opts MessageOptions) {
- testType(t, m)
-
- md := m.ProtoReflect().Descriptor()
- m1 := m.ProtoReflect().New()
+ md := mt.Descriptor()
+ m1 := mt.New()
for i := 0; i < md.Fields().Len(); i++ {
fd := md.Fields().Get(i)
testField(t, m1, fd)
}
- if opts.ExtensionTypes == nil {
- preg.GlobalTypes.RangeExtensionsByMessage(md.FullName(), func(e pref.ExtensionType) bool {
- opts.ExtensionTypes = append(opts.ExtensionTypes, e)
- return true
- })
+ if test.Resolver == nil {
+ test.Resolver = protoregistry.GlobalTypes
}
- for _, xt := range opts.ExtensionTypes {
+ var extTypes []pref.ExtensionType
+ test.Resolver.RangeExtensionsByMessage(md.FullName(), func(e pref.ExtensionType) bool {
+ extTypes = append(extTypes, e)
+ return true
+ })
+ for _, xt := range extTypes {
testField(t, m1, xt.TypeDescriptor())
}
for i := 0; i < md.Oneofs().Len(); i++ {
@@ -66,9 +61,9 @@
testUnknown(t, m1)
// Test round-trip marshal/unmarshal.
- m2 := m.ProtoReflect().New().Interface()
+ m2 := mt.New().Interface()
populateMessage(m2.ProtoReflect(), 1, nil)
- for _, xt := range opts.ExtensionTypes {
+ for _, xt := range extTypes {
m2.ProtoReflect().Set(xt.TypeDescriptor(), newValue(m2.ProtoReflect(), xt.TypeDescriptor(), 1, nil))
}
b, err := proto.MarshalOptions{
@@ -77,10 +72,10 @@
if err != nil {
t.Errorf("Marshal() = %v, want nil\n%v", err, prototext.Format(m2))
}
- m3 := m.ProtoReflect().New().Interface()
+ m3 := mt.New().Interface()
if err := (proto.UnmarshalOptions{
AllowPartial: true,
- Resolver: opts.Resolver,
+ Resolver: test.Resolver,
}.Unmarshal(b, m3)); err != nil {
t.Errorf("Unmarshal() = %v, want nil\n%v", err, prototext.Format(m2))
}
@@ -89,7 +84,8 @@
}
}
-func testType(t testing.TB, m proto.Message) {
+func testType(t testing.TB, mt pref.MessageType) {
+ m := mt.New().Interface()
want := reflect.TypeOf(m)
if got := reflect.TypeOf(m.ProtoReflect().Interface()); got != want {
t.Errorf("type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().Interface()): %v != %v", got, want)
diff --git a/testing/prototest/prototest_test.go b/testing/prototest/prototest_test.go
index 44fa0bc..4307d44 100644
--- a/testing/prototest/prototest_test.go
+++ b/testing/prototest/prototest_test.go
@@ -38,7 +38,7 @@
for _, m := range ms {
t.Run(fmt.Sprintf("%T", m), func(t *testing.T) {
- prototest.TestMessage(t, m, prototest.MessageOptions{})
+ prototest.Message{}.Test(t, m.ProtoReflect().Type())
})
}
}
diff --git a/types/dynamicpb/dynamic_test.go b/types/dynamicpb/dynamic_test.go
index 281a01e..78e93bb 100644
--- a/types/dynamicpb/dynamic_test.go
+++ b/types/dynamicpb/dynamic_test.go
@@ -23,24 +23,20 @@
(*test3pb.TestAllTypes)(nil),
(*testpb.TestAllExtensions)(nil),
} {
- prototest.TestMessage(t, dynamicpb.NewMessage(message.ProtoReflect().Descriptor()), prototest.MessageOptions{})
+ mt := dynamicpb.NewMessageType(message.ProtoReflect().Descriptor())
+ prototest.Message{}.Test(t, mt)
}
}
func TestDynamicExtensions(t *testing.T) {
- file, err := preg.GlobalFiles.FindFileByPath("internal/testprotos/test/ext.proto")
- if err != nil {
- t.Fatal(err)
+ for _, message := range []proto.Message{
+ (*testpb.TestAllExtensions)(nil),
+ } {
+ mt := dynamicpb.NewMessageType(message.ProtoReflect().Descriptor())
+ prototest.Message{
+ Resolver: extResolver{},
+ }.Test(t, mt)
}
-
- md := (&testpb.TestAllExtensions{}).ProtoReflect().Descriptor()
- opts := prototest.MessageOptions{
- Resolver: extResolver{},
- }
- for i := 0; i < file.Extensions().Len(); i++ {
- opts.ExtensionTypes = append(opts.ExtensionTypes, dynamicpb.NewExtensionType(file.Extensions().Get(i)))
- }
- prototest.TestMessage(t, dynamicpb.NewMessage(md), opts)
}
type extResolver struct{}
@@ -60,3 +56,9 @@
}
return dynamicpb.NewExtensionType(xt.TypeDescriptor().Descriptor()), nil
}
+
+func (extResolver) RangeExtensionsByMessage(message pref.FullName, f func(pref.ExtensionType) bool) {
+ preg.GlobalTypes.RangeExtensionsByMessage(message, func(xt pref.ExtensionType) bool {
+ return f(dynamicpb.NewExtensionType(xt.TypeDescriptor().Descriptor()))
+ })
+}