proto, encoding/protojson, encoding/prototext: use Resolver interface

Instead of accepting a concrete protoregistry.Types type,
accept an interface that provides the necessary functionality
to perform the serialization.

The advantages of this approach:
* There is no need for complex logic to allow a Parent or custom
Resolver on the protoregistry.Types type.
* Users can pass their own custom resolver implementations directly
to the serialization functions.
* This is a more principled approach to plumbing custom resolvers
than the previous approach of overloading behavior on the concrete
Types type.

The disadvantages of this approach:
* A pointer to a concrete type is 8B, while an interface is 16B.
However, the expansion of the {Marshal,Unmarshal}Options structs
should be a concern solved separately from how to plumb custom resolvers.
* The resolver interfaces as defined today may be insufficient to
provide functionality needed in the future if protobuf expands its
feature set. For example, let's suppose the Any message permits
directly representing a enum by name. This would require the ability
to lookup an enum by name. To support that hypothetical need,
we can document that the serializers type-assert the provided Resolver
to a EnumTypeResolver and use that if possible. There is some loss
of type safety with this approach, but provides a clear path forward.

Change-Id: I81ca80e59335d36be6b43d57ec8e17abfdfa3bad
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/177044
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/encoding/protojson/decode.go b/encoding/protojson/decode.go
index 99e1a8b..a40d1e2 100644
--- a/encoding/protojson/decode.go
+++ b/encoding/protojson/decode.go
@@ -36,10 +36,13 @@
 	// If DiscardUnknown is set, unknown fields are ignored.
 	DiscardUnknown bool
 
-	// Resolver is the registry used for type lookups when unmarshaling extensions
-	// and processing Any. If Resolver is not set, unmarshaling will default to
-	// using protoregistry.GlobalTypes.
-	Resolver *protoregistry.Types
+	// Resolver is used for looking up types when unmarshaling
+	// google.protobuf.Any messages or extension fields.
+	// If nil, this defaults to using protoregistry.GlobalTypes.
+	Resolver interface {
+		protoregistry.MessageTypeResolver
+		protoregistry.ExtensionTypeResolver
+	}
 
 	decoder *json.Decoder
 }
diff --git a/encoding/protojson/encode.go b/encoding/protojson/encode.go
index bd0ee3d..d789986 100644
--- a/encoding/protojson/encode.go
+++ b/encoding/protojson/encode.go
@@ -36,10 +36,11 @@
 	// composed of space or tab characters.
 	Indent string
 
-	// Resolver is the registry used for type lookups when marshaling
-	// google.protobuf.Any messages. If Resolver is not set, marshaling will
-	// default to using protoregistry.GlobalTypes.
-	Resolver *protoregistry.Types
+	// Resolver is used for looking up types when expanding google.protobuf.Any
+	// messages. If nil, this defaults to using protoregistry.GlobalTypes.
+	Resolver interface {
+		protoregistry.MessageTypeResolver
+	}
 
 	encoder *json.Encoder
 }
diff --git a/encoding/prototext/decode.go b/encoding/prototext/decode.go
index efc4c7a..20bdfe6 100644
--- a/encoding/prototext/decode.go
+++ b/encoding/prototext/decode.go
@@ -33,10 +33,13 @@
 	// return error if there are any missing required fields.
 	AllowPartial bool
 
-	// Resolver is the registry used for type lookups when unmarshaling extensions
-	// and processing Any. If Resolver is not set, unmarshaling will default to
-	// using protoregistry.GlobalTypes.
-	Resolver *protoregistry.Types
+	// Resolver is used for looking up types when unmarshaling
+	// google.protobuf.Any messages or extension fields.
+	// If nil, this defaults to using protoregistry.GlobalTypes.
+	Resolver interface {
+		protoregistry.MessageTypeResolver
+		protoregistry.ExtensionTypeResolver
+	}
 }
 
 // Unmarshal reads the given []byte and populates the given proto.Message using options in
diff --git a/encoding/prototext/encode.go b/encoding/prototext/encode.go
index 8f06370..d86492e 100644
--- a/encoding/prototext/encode.go
+++ b/encoding/prototext/encode.go
@@ -39,11 +39,11 @@
 	// composed of space or tab characters.
 	Indent string
 
-	// Resolver is the registry used for type lookups when marshaling out
-	// google.protobuf.Any messages in expanded form. If Resolver is not set,
-	// marshaling will default to using protoregistry.GlobalTypes.  If a type is
-	// not found, an Any message will be marshaled as a regular message.
-	Resolver *protoregistry.Types
+	// Resolver is used for looking up types when expanding google.protobuf.Any
+	// messages. If nil, this defaults to using protoregistry.GlobalTypes.
+	Resolver interface {
+		protoregistry.MessageTypeResolver
+	}
 }
 
 // Marshal writes the given proto.Message in textproto format using options in MarshalOptions object.
diff --git a/proto/decode.go b/proto/decode.go
index 3d3d0da..b376685 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -28,7 +28,9 @@
 
 	// Resolver is used for looking up types when unmarshaling extension fields.
 	// If nil, this defaults to using protoregistry.GlobalTypes.
-	Resolver *protoregistry.Types
+	Resolver interface {
+		protoregistry.ExtensionTypeResolver
+	}
 
 	pragma.NoUnkeyedLiterals
 }
diff --git a/reflect/protoregistry/registry.go b/reflect/protoregistry/registry.go
index d3ea3fc..b778cd2 100644
--- a/reflect/protoregistry/registry.go
+++ b/reflect/protoregistry/registry.go
@@ -276,38 +276,56 @@
 	_ Type = protoreflect.ExtensionType(nil)
 )
 
+// MessageTypeResolver is an interface for looking up messages.
+//
+// A compliant implementation must deterministically return the same type
+// if no error is encountered.
+//
+// The Types type implements this interface.
+type MessageTypeResolver interface {
+	// FindMessageByName looks up a message by its full name.
+	// E.g., "google.protobuf.Any"
+	//
+	// This return (nil, NotFound) if not found.
+	FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error)
+
+	// FindMessageByURL looks up a message by a URL identifier.
+	// See documentation on google.protobuf.Any.type_url for the URL format.
+	//
+	// This returns (nil, NotFound) if not found.
+	FindMessageByURL(url string) (protoreflect.MessageType, error)
+}
+
+// ExtensionTypeResolver is an interface for looking up extensions.
+//
+// A compliant implementation must deterministically return the same type
+// if no error is encountered.
+//
+// The Types type implements this interface.
+type ExtensionTypeResolver interface {
+	// FindExtensionByName looks up a extension field by the field's full name.
+	// Note that this is the full name of the field as determined by
+	// where the extension is declared and is unrelated to the full name of the
+	// message being extended.
+	//
+	// This returns (nil, NotFound) if not found.
+	FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
+
+	// FindExtensionByNumber looks up a extension field by the field number
+	// within some parent message, identified by full name.
+	//
+	// This returns (nil, NotFound) if not found.
+	FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
+}
+
+var (
+	_ MessageTypeResolver   = (*Types)(nil)
+	_ ExtensionTypeResolver = (*Types)(nil)
+)
+
 // Types is a registry for looking up or iterating over descriptor types.
 // The Find and Range methods are safe for concurrent use.
 type Types struct {
-	// Parent sets the parent registry to consult if a find operation
-	// could not locate the appropriate entry.
-	//
-	// Setting a parent results in each Range operation also iterating over the
-	// entries contained within the parent. In such a case, it is possible for
-	// Range to emit duplicates (since they may exist in both child and parent).
-	// Range iteration is guaranteed to iterate over local entries before
-	// iterating over parent entries.
-	Parent *Types
-
-	// Resolver sets the local resolver to consult if the local registry does
-	// not contain an entry. The resolver takes precedence over the parent.
-	//
-	// The url is a URL where the full name of the type is the last segment
-	// of the path (i.e. string following the last '/' character).
-	// When missing a '/' character, the URL is the full name of the type.
-	// See documentation on the google.protobuf.Any.type_url field for details.
-	//
-	// If the resolver returns a result, it is not automatically registered
-	// into the local registry. Thus, a resolver function should cache results
-	// such that it deterministically returns the same result given the
-	// same URL assuming the error returned is nil or NotFound.
-	//
-	// If the resolver returns the NotFound error, the registry will consult the
-	// parent registry if it is set.
-	//
-	// Setting a resolver has no effect on the result of each Range operation.
-	Resolver func(url string) (Type, error)
-
 	// TODO: The syntax of the URL is ill-defined and the protobuf team recently
 	// changed the documented semantics in a way that breaks prior usages.
 	// I do not believe they can do this and need to sync up with the
@@ -342,7 +360,6 @@
 // NewTypes returns a registry initialized with the provided set of types.
 // If there are conflicts, the first one takes precedence.
 func NewTypes(typs ...Type) *Types {
-	// TODO: Allow setting resolver and parent via constructor?
 	r := new(Types)
 	r.Register(typs...) // ignore errors; first takes precedence
 	return r
@@ -418,25 +435,17 @@
 //
 // This returns (nil, NotFound) if not found.
 func (r *Types) FindEnumByName(enum protoreflect.FullName) (protoreflect.EnumType, error) {
-	r.globalCheck()
 	if r == nil {
 		return nil, NotFound
 	}
 	v, _ := r.typesByName[enum]
-	if v == nil && r.Resolver != nil {
-		var err error
-		v, err = r.Resolver(string(enum))
-		if err != nil && err != NotFound {
-			return nil, err
-		}
-	}
 	if v != nil {
 		if et, _ := v.(protoreflect.EnumType); et != nil {
 			return et, nil
 		}
 		return nil, errors.New("found wrong type: got %v, want enum", typeName(v))
 	}
-	return r.Parent.FindEnumByName(enum)
+	return nil, NotFound
 }
 
 // FindMessageByName looks up a message by its full name.
@@ -449,11 +458,10 @@
 }
 
 // FindMessageByURL looks up a message by a URL identifier.
-// See Resolver for the format of the URL.
+// See documentation on google.protobuf.Any.type_url for the URL format.
 //
 // This returns (nil, NotFound) if not found.
 func (r *Types) FindMessageByURL(url string) (protoreflect.MessageType, error) {
-	r.globalCheck()
 	if r == nil {
 		return nil, NotFound
 	}
@@ -463,20 +471,13 @@
 	}
 
 	v, _ := r.typesByName[message]
-	if v == nil && r.Resolver != nil {
-		var err error
-		v, err = r.Resolver(url)
-		if err != nil && err != NotFound {
-			return nil, err
-		}
-	}
 	if v != nil {
 		if mt, _ := v.(protoreflect.MessageType); mt != nil {
 			return mt, nil
 		}
 		return nil, errors.New("found wrong type: got %v, want message", typeName(v))
 	}
-	return r.Parent.FindMessageByURL(url)
+	return nil, NotFound
 }
 
 // FindExtensionByName looks up a extension field by the field's full name.
@@ -486,25 +487,17 @@
 //
 // This returns (nil, NotFound) if not found.
 func (r *Types) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
-	r.globalCheck()
 	if r == nil {
 		return nil, NotFound
 	}
 	v, _ := r.typesByName[field]
-	if v == nil && r.Resolver != nil {
-		var err error
-		v, err = r.Resolver(string(field))
-		if err != nil && err != NotFound {
-			return nil, err
-		}
-	}
 	if v != nil {
 		if xt, _ := v.(protoreflect.ExtensionType); xt != nil {
 			return xt, nil
 		}
 		return nil, errors.New("found wrong type: got %v, want extension", typeName(v))
 	}
-	return r.Parent.FindExtensionByName(field)
+	return nil, NotFound
 }
 
 // FindExtensionByNumber looks up a extension field by the field number
@@ -512,20 +505,18 @@
 //
 // This returns (nil, NotFound) if not found.
 func (r *Types) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
-	r.globalCheck()
 	if r == nil {
 		return nil, NotFound
 	}
 	if xt, ok := r.extensionsByMessage[message][field]; ok {
 		return xt, nil
 	}
-	return r.Parent.FindExtensionByNumber(message, field)
+	return nil, NotFound
 }
 
 // RangeEnums iterates over all registered enums.
 // Iteration order is undefined.
 func (r *Types) RangeEnums(f func(protoreflect.EnumType) bool) {
-	r.globalCheck()
 	if r == nil {
 		return
 	}
@@ -536,13 +527,11 @@
 			}
 		}
 	}
-	r.Parent.RangeEnums(f)
 }
 
 // RangeMessages iterates over all registered messages.
 // Iteration order is undefined.
 func (r *Types) RangeMessages(f func(protoreflect.MessageType) bool) {
-	r.globalCheck()
 	if r == nil {
 		return
 	}
@@ -553,13 +542,11 @@
 			}
 		}
 	}
-	r.Parent.RangeMessages(f)
 }
 
 // RangeExtensions iterates over all registered extensions.
 // Iteration order is undefined.
 func (r *Types) RangeExtensions(f func(protoreflect.ExtensionType) bool) {
-	r.globalCheck()
 	if r == nil {
 		return
 	}
@@ -570,13 +557,11 @@
 			}
 		}
 	}
-	r.Parent.RangeExtensions(f)
 }
 
 // RangeExtensionsByMessage iterates over all registered extensions filtered
 // by a given message type. Iteration order is undefined.
 func (r *Types) RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) {
-	r.globalCheck()
 	if r == nil {
 		return
 	}
@@ -585,13 +570,6 @@
 			return
 		}
 	}
-	r.Parent.RangeExtensionsByMessage(message, f)
-}
-
-func (r *Types) globalCheck() {
-	if r == GlobalTypes && (r.Parent != nil || r.Resolver != nil) {
-		panic("GlobalTypes.Parent and GlobalTypes.Resolver cannot be set")
-	}
 }
 
 func typeName(t Type) string {
diff --git a/reflect/protoregistry/registry_test.go b/reflect/protoregistry/registry_test.go
index 5da8f3d..055879a 100644
--- a/reflect/protoregistry/registry_test.go
+++ b/reflect/protoregistry/registry_test.go
@@ -302,41 +302,12 @@
 }
 
 func TestTypes(t *testing.T) {
-	// Suffix 1 in registry, 2 in parent, 3 in resolver.
 	mt1 := pimpl.Export{}.MessageTypeOf(&testpb.Message1{})
-	mt2 := pimpl.Export{}.MessageTypeOf(&testpb.Message2{})
-	mt3 := pimpl.Export{}.MessageTypeOf(&testpb.Message3{})
 	et1 := pimpl.Export{}.EnumTypeOf(testpb.Enum1_ONE)
-	et2 := pimpl.Export{}.EnumTypeOf(testpb.Enum2_UNO)
-	et3 := pimpl.Export{}.EnumTypeOf(testpb.Enum3_YI)
-	// Suffix indicates field number.
-	xt11 := testpb.E_StringField.Type
-	xt12 := testpb.E_EnumField.Type
-	xt13 := testpb.E_MessageField.Type
-	xt21 := testpb.E_Message4_MessageField.Type
-	xt22 := testpb.E_Message4_EnumField.Type
-	xt23 := testpb.E_Message4_StringField.Type
-	parent := &preg.Types{}
-	if err := parent.Register(mt2, et2, xt12, xt22); err != nil {
-		t.Fatalf("parent.Register() returns unexpected error: %v", err)
-	}
-	registry := &preg.Types{
-		Parent: parent,
-		Resolver: func(url string) (preg.Type, error) {
-			switch {
-			case strings.HasSuffix(url, "testprotos.Message3"):
-				return mt3, nil
-			case strings.HasSuffix(url, "testprotos.Enum3"):
-				return et3, nil
-			case strings.HasSuffix(url, "testprotos.message_field"):
-				return xt13, nil
-			case strings.HasSuffix(url, "testprotos.Message4.string_field"):
-				return xt23, nil
-			}
-			return nil, preg.NotFound
-		},
-	}
-	if err := registry.Register(mt1, et1, xt11, xt21); err != nil {
+	xt1 := testpb.E_StringField.Type
+	xt2 := testpb.E_Message4_MessageField.Type
+	registry := new(preg.Types)
+	if err := registry.Register(mt1, et1, xt1, xt2); err != nil {
 		t.Fatalf("registry.Register() returns unexpected error: %v", err)
 	}
 
@@ -350,12 +321,6 @@
 			name:        "testprotos.Message1",
 			messageType: mt1,
 		}, {
-			name:        "testprotos.Message2",
-			messageType: mt2,
-		}, {
-			name:        "testprotos.Message3",
-			messageType: mt3,
-		}, {
 			name:         "testprotos.NoSuchMessage",
 			wantErr:      true,
 			wantNotFound: true,
@@ -396,12 +361,6 @@
 			name:        "testprotos.Message1",
 			messageType: mt1,
 		}, {
-			name:        "foo.com/testprotos.Message2",
-			messageType: mt2,
-		}, {
-			name:        "/testprotos.Message3",
-			messageType: mt3,
-		}, {
 			name:         "type.googleapis.com/testprotos.Nada",
 			wantErr:      true,
 			wantNotFound: true,
@@ -436,12 +395,6 @@
 			name:     "testprotos.Enum1",
 			enumType: et1,
 		}, {
-			name:     "testprotos.Enum2",
-			enumType: et2,
-		}, {
-			name:     "testprotos.Enum3",
-			enumType: et3,
-		}, {
 			name:         "testprotos.None",
 			wantErr:      true,
 			wantNotFound: true,
@@ -474,22 +427,10 @@
 			wantNotFound  bool
 		}{{
 			name:          "testprotos.string_field",
-			extensionType: xt11,
-		}, {
-			name:          "testprotos.enum_field",
-			extensionType: xt12,
-		}, {
-			name:          "testprotos.message_field",
-			extensionType: xt13,
+			extensionType: xt1,
 		}, {
 			name:          "testprotos.Message4.message_field",
-			extensionType: xt21,
-		}, {
-			name:          "testprotos.Message4.enum_field",
-			extensionType: xt22,
-		}, {
-			name:          "testprotos.Message4.string_field",
-			extensionType: xt23,
+			extensionType: xt2,
 		}, {
 			name:         "testprotos.None",
 			wantErr:      true,
@@ -525,13 +466,8 @@
 		}{{
 			parent:        "testprotos.Message1",
 			number:        11,
-			extensionType: xt11,
+			extensionType: xt1,
 		}, {
-			parent:        "testprotos.Message1",
-			number:        12,
-			extensionType: xt12,
-		}, {
-			// FindExtensionByNumber does not use Resolver.
 			parent:       "testprotos.Message1",
 			number:       13,
 			wantErr:      true,
@@ -539,13 +475,8 @@
 		}, {
 			parent:        "testprotos.Message1",
 			number:        21,
-			extensionType: xt21,
+			extensionType: xt2,
 		}, {
-			parent:        "testprotos.Message1",
-			number:        22,
-			extensionType: xt22,
-		}, {
-			// FindExtensionByNumber does not use Resolver.
 			parent:       "testprotos.Message1",
 			number:       23,
 			wantErr:      true,
@@ -603,8 +534,7 @@
 	})
 
 	t.Run("RangeMessages", func(t *testing.T) {
-		// RangeMessages do not include messages from Resolver.
-		want := []preg.Type{mt1, mt2}
+		want := []preg.Type{mt1}
 		var got []preg.Type
 		registry.RangeMessages(func(mt pref.MessageType) bool {
 			got = append(got, mt)
@@ -618,8 +548,7 @@
 	})
 
 	t.Run("RangeEnums", func(t *testing.T) {
-		// RangeEnums do not include enums from Resolver.
-		want := []preg.Type{et1, et2}
+		want := []preg.Type{et1}
 		var got []preg.Type
 		registry.RangeEnums(func(et pref.EnumType) bool {
 			got = append(got, et)
@@ -633,8 +562,7 @@
 	})
 
 	t.Run("RangeExtensions", func(t *testing.T) {
-		// RangeExtensions do not include messages from Resolver.
-		want := []preg.Type{xt11, xt12, xt21, xt22}
+		want := []preg.Type{xt1, xt2}
 		var got []preg.Type
 		registry.RangeExtensions(func(xt pref.ExtensionType) bool {
 			got = append(got, xt)
@@ -648,8 +576,7 @@
 	})
 
 	t.Run("RangeExtensionsByMessage", func(t *testing.T) {
-		// RangeExtensions do not include messages from Resolver.
-		want := []preg.Type{xt11, xt12, xt21, xt22}
+		want := []preg.Type{xt1, xt2}
 		var got []preg.Type
 		registry.RangeExtensionsByMessage(pref.FullName("testprotos.Message1"), func(xt pref.ExtensionType) bool {
 			got = append(got, xt)
diff --git a/runtime/protoiface/methods.go b/runtime/protoiface/methods.go
index 0ba9d5c..af86d62 100644
--- a/runtime/protoiface/methods.go
+++ b/runtime/protoiface/methods.go
@@ -63,7 +63,9 @@
 type UnmarshalOptions struct {
 	AllowPartial   bool
 	DiscardUnknown bool
-	Resolver       *protoregistry.Types
+	Resolver       interface {
+		protoregistry.ExtensionTypeResolver
+	}
 
 	pragma.NoUnkeyedLiterals
 }