proto: make one test more general
Tweak the "nested unknown extension" test case's resolver to not depend
on the exact message being tested. Useful for if/when we want to run
these tests on other message implementations.
Change-Id: Id1722afd8e094ddb59cb3e5440f7994c20cfa681
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/217760
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/proto/testmessages_test.go b/proto/testmessages_test.go
index 045cb8f..d980fd1 100644
--- a/proto/testmessages_test.go
+++ b/proto/testmessages_test.go
@@ -10,6 +10,7 @@
"google.golang.org/protobuf/internal/impl"
"google.golang.org/protobuf/internal/protobuild"
"google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
legacypb "google.golang.org/protobuf/internal/testprotos/legacy"
@@ -1389,12 +1390,17 @@
desc: "nested unknown extension",
unmarshalOptions: proto.UnmarshalOptions{
DiscardUnknown: true,
- Resolver: func() protoregistry.ExtensionTypeResolver {
- types := &protoregistry.Types{}
- types.RegisterExtension(testpb.E_OptionalNestedMessage)
- types.RegisterExtension(testpb.E_OptionalInt32)
- return types
- }(),
+ Resolver: filterResolver{
+ filter: func(name protoreflect.FullName) bool {
+ switch name.Name() {
+ case "optional_nested_message",
+ "optional_int32":
+ return true
+ }
+ return false
+ },
+ resolver: protoregistry.GlobalTypes,
+ },
},
decodeTo: makeMessages(protobuild.Message{
"optional_nested_message": protobuild.Message{
@@ -1847,3 +1853,26 @@
}.Marshal(),
},
}
+
+type filterResolver struct {
+ filter func(name protoreflect.FullName) bool
+ resolver protoregistry.ExtensionTypeResolver
+}
+
+func (f filterResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
+ if !f.filter(field) {
+ return nil, protoregistry.NotFound
+ }
+ return f.resolver.FindExtensionByName(field)
+}
+
+func (f filterResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
+ xt, err := f.resolver.FindExtensionByNumber(message, field)
+ if err != nil {
+ return nil, err
+ }
+ if !f.filter(xt.TypeDescriptor().FullName()) {
+ return nil, protoregistry.NotFound
+ }
+ return xt, nil
+}