testing/prototest: refactor prototest API

For consistency with other options types in the protobuf module, make
the test function a method of the options.

Drop the ExtensionTypes option and just look up the extension types to
test with in the provided resolver.

Change-Id: I7918bd10b7c003e4af56d27521d30218653d5b4d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/219142
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/testing/prototest/prototest.go b/testing/prototest/prototest.go
index 0c6b700..4fccb70 100644
--- a/testing/prototest/prototest.go
+++ b/testing/prototest/prototest.go
@@ -17,47 +17,42 @@
 	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/proto"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
-	preg "google.golang.org/protobuf/reflect/protoregistry"
+	"google.golang.org/protobuf/reflect/protoregistry"
 )
 
 // TODO: Test invalid field descriptors or oneof descriptors.
 // TODO: This should test the functionality that can be provided by fast-paths.
 
-// MessageOptions configure message tests.
-type MessageOptions struct {
-	// ExtensionTypes is a list of types to test with.
-	//
-	// If nil, TestMessage will look for extension types in the global registry.
-	ExtensionTypes []pref.ExtensionType
-
-	// Resolver is used for looking up types when unmarshaling extension fields.
+// Message tests a message implemention.
+type Message struct {
+	// Resolver is used to determine the list of extension fields to test with.
 	// If nil, this defaults to using protoregistry.GlobalTypes.
 	Resolver interface {
-		preg.ExtensionTypeResolver
+		FindExtensionByName(field pref.FullName) (pref.ExtensionType, error)
+		FindExtensionByNumber(message pref.FullName, field pref.FieldNumber) (pref.ExtensionType, error)
+		RangeExtensionsByMessage(message pref.FullName, f func(pref.ExtensionType) bool)
 	}
 }
 
-// TODO(blocks): TestMessage should not take in MessageOptions,
-// but have a MessageOptions.Test method instead.
+// Test performs tests on a MessageType implementation.
+func (test Message) Test(t testing.TB, mt pref.MessageType) {
+	testType(t, mt)
 
-// TestMessage runs the provided m through a series of tests
-// exercising the protobuf reflection API.
-func TestMessage(t testing.TB, m proto.Message, opts MessageOptions) {
-	testType(t, m)
-
-	md := m.ProtoReflect().Descriptor()
-	m1 := m.ProtoReflect().New()
+	md := mt.Descriptor()
+	m1 := mt.New()
 	for i := 0; i < md.Fields().Len(); i++ {
 		fd := md.Fields().Get(i)
 		testField(t, m1, fd)
 	}
-	if opts.ExtensionTypes == nil {
-		preg.GlobalTypes.RangeExtensionsByMessage(md.FullName(), func(e pref.ExtensionType) bool {
-			opts.ExtensionTypes = append(opts.ExtensionTypes, e)
-			return true
-		})
+	if test.Resolver == nil {
+		test.Resolver = protoregistry.GlobalTypes
 	}
-	for _, xt := range opts.ExtensionTypes {
+	var extTypes []pref.ExtensionType
+	test.Resolver.RangeExtensionsByMessage(md.FullName(), func(e pref.ExtensionType) bool {
+		extTypes = append(extTypes, e)
+		return true
+	})
+	for _, xt := range extTypes {
 		testField(t, m1, xt.TypeDescriptor())
 	}
 	for i := 0; i < md.Oneofs().Len(); i++ {
@@ -66,9 +61,9 @@
 	testUnknown(t, m1)
 
 	// Test round-trip marshal/unmarshal.
-	m2 := m.ProtoReflect().New().Interface()
+	m2 := mt.New().Interface()
 	populateMessage(m2.ProtoReflect(), 1, nil)
-	for _, xt := range opts.ExtensionTypes {
+	for _, xt := range extTypes {
 		m2.ProtoReflect().Set(xt.TypeDescriptor(), newValue(m2.ProtoReflect(), xt.TypeDescriptor(), 1, nil))
 	}
 	b, err := proto.MarshalOptions{
@@ -77,10 +72,10 @@
 	if err != nil {
 		t.Errorf("Marshal() = %v, want nil\n%v", err, prototext.Format(m2))
 	}
-	m3 := m.ProtoReflect().New().Interface()
+	m3 := mt.New().Interface()
 	if err := (proto.UnmarshalOptions{
 		AllowPartial: true,
-		Resolver:     opts.Resolver,
+		Resolver:     test.Resolver,
 	}.Unmarshal(b, m3)); err != nil {
 		t.Errorf("Unmarshal() = %v, want nil\n%v", err, prototext.Format(m2))
 	}
@@ -89,7 +84,8 @@
 	}
 }
 
-func testType(t testing.TB, m proto.Message) {
+func testType(t testing.TB, mt pref.MessageType) {
+	m := mt.New().Interface()
 	want := reflect.TypeOf(m)
 	if got := reflect.TypeOf(m.ProtoReflect().Interface()); got != want {
 		t.Errorf("type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().Interface()): %v != %v", got, want)
diff --git a/testing/prototest/prototest_test.go b/testing/prototest/prototest_test.go
index 44fa0bc..4307d44 100644
--- a/testing/prototest/prototest_test.go
+++ b/testing/prototest/prototest_test.go
@@ -38,7 +38,7 @@
 
 	for _, m := range ms {
 		t.Run(fmt.Sprintf("%T", m), func(t *testing.T) {
-			prototest.TestMessage(t, m, prototest.MessageOptions{})
+			prototest.Message{}.Test(t, m.ProtoReflect().Type())
 		})
 	}
 }
diff --git a/types/dynamicpb/dynamic_test.go b/types/dynamicpb/dynamic_test.go
index 281a01e..78e93bb 100644
--- a/types/dynamicpb/dynamic_test.go
+++ b/types/dynamicpb/dynamic_test.go
@@ -23,24 +23,20 @@
 		(*test3pb.TestAllTypes)(nil),
 		(*testpb.TestAllExtensions)(nil),
 	} {
-		prototest.TestMessage(t, dynamicpb.NewMessage(message.ProtoReflect().Descriptor()), prototest.MessageOptions{})
+		mt := dynamicpb.NewMessageType(message.ProtoReflect().Descriptor())
+		prototest.Message{}.Test(t, mt)
 	}
 }
 
 func TestDynamicExtensions(t *testing.T) {
-	file, err := preg.GlobalFiles.FindFileByPath("internal/testprotos/test/ext.proto")
-	if err != nil {
-		t.Fatal(err)
+	for _, message := range []proto.Message{
+		(*testpb.TestAllExtensions)(nil),
+	} {
+		mt := dynamicpb.NewMessageType(message.ProtoReflect().Descriptor())
+		prototest.Message{
+			Resolver: extResolver{},
+		}.Test(t, mt)
 	}
-
-	md := (&testpb.TestAllExtensions{}).ProtoReflect().Descriptor()
-	opts := prototest.MessageOptions{
-		Resolver: extResolver{},
-	}
-	for i := 0; i < file.Extensions().Len(); i++ {
-		opts.ExtensionTypes = append(opts.ExtensionTypes, dynamicpb.NewExtensionType(file.Extensions().Get(i)))
-	}
-	prototest.TestMessage(t, dynamicpb.NewMessage(md), opts)
 }
 
 type extResolver struct{}
@@ -60,3 +56,9 @@
 	}
 	return dynamicpb.NewExtensionType(xt.TypeDescriptor().Descriptor()), nil
 }
+
+func (extResolver) RangeExtensionsByMessage(message pref.FullName, f func(pref.ExtensionType) bool) {
+	preg.GlobalTypes.RangeExtensionsByMessage(message, func(xt pref.ExtensionType) bool {
+		return f(dynamicpb.NewExtensionType(xt.TypeDescriptor().Descriptor()))
+	})
+}