proto: make one test more general

Tweak the "nested unknown extension" test case's resolver to not depend
on the exact message being tested. Useful for if/when we want to run
these tests on other message implementations.

Change-Id: Id1722afd8e094ddb59cb3e5440f7994c20cfa681
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/217760
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/proto/testmessages_test.go b/proto/testmessages_test.go
index 045cb8f..d980fd1 100644
--- a/proto/testmessages_test.go
+++ b/proto/testmessages_test.go
@@ -10,6 +10,7 @@
 	"google.golang.org/protobuf/internal/impl"
 	"google.golang.org/protobuf/internal/protobuild"
 	"google.golang.org/protobuf/proto"
+	"google.golang.org/protobuf/reflect/protoreflect"
 	"google.golang.org/protobuf/reflect/protoregistry"
 
 	legacypb "google.golang.org/protobuf/internal/testprotos/legacy"
@@ -1389,12 +1390,17 @@
 		desc: "nested unknown extension",
 		unmarshalOptions: proto.UnmarshalOptions{
 			DiscardUnknown: true,
-			Resolver: func() protoregistry.ExtensionTypeResolver {
-				types := &protoregistry.Types{}
-				types.RegisterExtension(testpb.E_OptionalNestedMessage)
-				types.RegisterExtension(testpb.E_OptionalInt32)
-				return types
-			}(),
+			Resolver: filterResolver{
+				filter: func(name protoreflect.FullName) bool {
+					switch name.Name() {
+					case "optional_nested_message",
+						"optional_int32":
+						return true
+					}
+					return false
+				},
+				resolver: protoregistry.GlobalTypes,
+			},
 		},
 		decodeTo: makeMessages(protobuild.Message{
 			"optional_nested_message": protobuild.Message{
@@ -1847,3 +1853,26 @@
 		}.Marshal(),
 	},
 }
+
+type filterResolver struct {
+	filter   func(name protoreflect.FullName) bool
+	resolver protoregistry.ExtensionTypeResolver
+}
+
+func (f filterResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
+	if !f.filter(field) {
+		return nil, protoregistry.NotFound
+	}
+	return f.resolver.FindExtensionByName(field)
+}
+
+func (f filterResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
+	xt, err := f.resolver.FindExtensionByNumber(message, field)
+	if err != nil {
+		return nil, err
+	}
+	if !f.filter(xt.TypeDescriptor().FullName()) {
+		return nil, protoregistry.NotFound
+	}
+	return xt, nil
+}