reflect/protoregistry: centralize MessageSet extension resolution logic

Centralize the MessageSet extension resolution logic in the registry.
This avoids needless replication of this exact logic in multiple places
(for JSON and text) and elsewhere.

Change-Id: I70bfea899e295e8c589f418965bf0dd099f93628
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/240077
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/encoding/protojson/decode.go b/encoding/protojson/decode.go
index 9bf4e8c..5ba9ebf 100644
--- a/encoding/protojson/decode.go
+++ b/encoding/protojson/decode.go
@@ -170,7 +170,7 @@
 		if strings.HasPrefix(name, "[") && strings.HasSuffix(name, "]") {
 			// Only extension names are in [name] format.
 			extName := pref.FullName(name[1 : len(name)-1])
-			extType, err := d.findExtension(extName)
+			extType, err := d.opts.Resolver.FindExtensionByName(extName)
 			if err != nil && err != protoregistry.NotFound {
 				return d.newError(tok.Pos(), "unable to resolve %s: %v", tok.RawString(), err)
 			}
@@ -257,15 +257,6 @@
 	}
 }
 
-// findExtension returns protoreflect.ExtensionType from the resolver if found.
-func (d decoder) findExtension(xtName pref.FullName) (pref.ExtensionType, error) {
-	xt, err := d.opts.Resolver.FindExtensionByName(xtName)
-	if err == nil {
-		return xt, nil
-	}
-	return messageset.FindMessageSetExtension(d.opts.Resolver, xtName)
-}
-
 func isKnownValue(fd pref.FieldDescriptor) bool {
 	md := fd.Message()
 	return md != nil && md.FullName() == genid.Value_message_fullname
diff --git a/encoding/protojson/decode_test.go b/encoding/protojson/decode_test.go
index 5e3dcff..4791f65 100644
--- a/encoding/protojson/decode_test.go
+++ b/encoding/protojson/decode_test.go
@@ -1403,7 +1403,7 @@
     "optString": "not a messageset extension"
   }
 }`,
-		wantErr: `unknown field "[pb2.FakeMessageSetExtension]"`,
+		wantErr: `unable to resolve "[pb2.FakeMessageSetExtension]": found wrong type`,
 		skip:    !flags.ProtoLegacy,
 	}, {
 		desc:         "not real MessageSet 3",
diff --git a/encoding/prototext/decode.go b/encoding/prototext/decode.go
index cab95a4..8cce1e0 100644
--- a/encoding/prototext/decode.go
+++ b/encoding/prototext/decode.go
@@ -172,7 +172,7 @@
 
 		case text.TypeName:
 			// Handle extensions only. This code path is not for Any.
-			xt, xtErr = d.findExtension(pref.FullName(tok.TypeName()))
+			xt, xtErr = d.opts.Resolver.FindExtensionByName(pref.FullName(tok.TypeName()))
 
 		case text.FieldNumber:
 			isFieldNumberName = true
@@ -269,15 +269,6 @@
 	return nil
 }
 
-// findExtension returns protoreflect.ExtensionType from the Resolver if found.
-func (d decoder) findExtension(xtName pref.FullName) (pref.ExtensionType, error) {
-	xt, err := d.opts.Resolver.FindExtensionByName(xtName)
-	if err == nil {
-		return xt, nil
-	}
-	return messageset.FindMessageSetExtension(d.opts.Resolver, xtName)
-}
-
 // unmarshalSingular unmarshals a non-repeated field value specified by the
 // given FieldDescriptor.
 func (d decoder) unmarshalSingular(fd pref.FieldDescriptor, m pref.Message) error {
diff --git a/encoding/prototext/decode_test.go b/encoding/prototext/decode_test.go
index dceded1..441dcb2 100644
--- a/encoding/prototext/decode_test.go
+++ b/encoding/prototext/decode_test.go
@@ -1508,7 +1508,7 @@
   opt_string: "not a messageset extension"
 }
 `,
-		wantErr: "unknown field: [pb2.FakeMessageSetExtension]",
+		wantErr: `unable to resolve [[pb2.FakeMessageSetExtension]]: found wrong type`,
 		skip:    !flags.ProtoLegacy,
 	}, {
 		desc:         "not real MessageSet 3",
diff --git a/internal/encoding/messageset/messageset.go b/internal/encoding/messageset/messageset.go
index b1eeea5..453a81a 100644
--- a/internal/encoding/messageset/messageset.go
+++ b/internal/encoding/messageset/messageset.go
@@ -11,7 +11,6 @@
 	"google.golang.org/protobuf/encoding/protowire"
 	"google.golang.org/protobuf/internal/errors"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
-	preg "google.golang.org/protobuf/reflect/protoregistry"
 )
 
 // The MessageSet wire format is equivalent to a message defiend as follows,
@@ -48,33 +47,17 @@
 	return ok && xmd.IsMessageSet()
 }
 
-// IsMessageSetExtension reports this field extends a MessageSet.
+// IsMessageSetExtension reports this field properly extends a MessageSet.
 func IsMessageSetExtension(fd pref.FieldDescriptor) bool {
-	if fd.Name() != ExtensionName {
+	switch {
+	case fd.Name() != ExtensionName:
+		return false
+	case !IsMessageSet(fd.ContainingMessage()):
+		return false
+	case fd.FullName().Parent() != fd.Message().FullName():
 		return false
 	}
-	if fd.FullName().Parent() != fd.Message().FullName() {
-		return false
-	}
-	return IsMessageSet(fd.ContainingMessage())
-}
-
-// FindMessageSetExtension locates a MessageSet extension field by name.
-// In text and JSON formats, the extension name used is the message itself.
-// The extension field name is derived by appending ExtensionName.
-func FindMessageSetExtension(r preg.ExtensionTypeResolver, s pref.FullName) (pref.ExtensionType, error) {
-	name := s.Append(ExtensionName)
-	xt, err := r.FindExtensionByName(name)
-	if err != nil {
-		if err == preg.NotFound {
-			return nil, err
-		}
-		return nil, errors.Wrap(err, "%q", name)
-	}
-	if !IsMessageSetExtension(xt.TypeDescriptor()) {
-		return nil, preg.NotFound
-	}
-	return xt, nil
+	return true
 }
 
 // SizeField returns the size of a MessageSet item field containing an extension
diff --git a/reflect/protoregistry/registry.go b/reflect/protoregistry/registry.go
index 5e5f967..7dcf4d9 100644
--- a/reflect/protoregistry/registry.go
+++ b/reflect/protoregistry/registry.go
@@ -21,7 +21,9 @@
 	"strings"
 	"sync"
 
+	"google.golang.org/protobuf/internal/encoding/messageset"
 	"google.golang.org/protobuf/internal/errors"
+	"google.golang.org/protobuf/internal/flags"
 	"google.golang.org/protobuf/reflect/protoreflect"
 )
 
@@ -613,6 +615,26 @@
 		if xt, _ := v.(protoreflect.ExtensionType); xt != nil {
 			return xt, nil
 		}
+
+		// MessageSet extensions are special in that the name of the extension
+		// is the name of the message type used to extend the MessageSet.
+		// This naming scheme is used by text and JSON serialization.
+		//
+		// This feature is protected by the ProtoLegacy flag since MessageSets
+		// are a proto1 feature that is long deprecated.
+		if flags.ProtoLegacy {
+			if _, ok := v.(protoreflect.MessageType); ok {
+				field := field.Append(messageset.ExtensionName)
+				if v := r.typesByName[field]; v != nil {
+					if xt, _ := v.(protoreflect.ExtensionType); xt != nil {
+						if messageset.IsMessageSetExtension(xt.TypeDescriptor()) {
+							return xt, nil
+						}
+					}
+				}
+			}
+		}
+
 		return nil, errors.New("found wrong type: got %v, want extension", typeName(v))
 	}
 	return nil, NotFound