reflect/protoreflect: add MessageFieldTypes

The MessageFieldTypes interface (if implemented by a MessageType)
provides Go type information about the fields if they are
an enum or message type.

Change-Id: I68b20f5726377f6b0f2c20a8b6e45f9802b43f67
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/236777
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/api_export.go b/internal/impl/api_export.go
index b597452..abee5f3 100644
--- a/internal/impl/api_export.go
+++ b/internal/impl/api_export.go
@@ -167,7 +167,7 @@
 	if mv := (Export{}).protoMessageV2Of(m); mv != nil {
 		return mv.ProtoReflect().Type()
 	}
-	return legacyLoadMessageInfo(reflect.TypeOf(m), "")
+	return legacyLoadMessageType(reflect.TypeOf(m), "")
 }
 
 // MessageStringOf returns the message value as a string,
diff --git a/internal/impl/legacy_export.go b/internal/impl/legacy_export.go
index c3d741c..e3fb0b5 100644
--- a/internal/impl/legacy_export.go
+++ b/internal/impl/legacy_export.go
@@ -30,7 +30,7 @@
 	if mv := (Export{}).protoMessageV2Of(m); mv != nil {
 		return mv.ProtoReflect().Type()
 	}
-	return legacyLoadMessageInfo(reflect.TypeOf(m), name)
+	return legacyLoadMessageType(reflect.TypeOf(m), name)
 }
 
 // UnmarshalJSONEnum unmarshals an enum from a JSON-encoded input.
diff --git a/internal/impl/legacy_message.go b/internal/impl/legacy_message.go
index 06c68e1..d5347cc 100644
--- a/internal/impl/legacy_message.go
+++ b/internal/impl/legacy_message.go
@@ -32,6 +32,16 @@
 	return mt.MessageOf(v.Interface())
 }
 
+// legacyLoadMessageType dynamically loads a protoreflect.Type for t,
+// where t must be not implement the v2 API already.
+// The provided name is used if it cannot be determined from the message.
+func legacyLoadMessageType(t reflect.Type, name pref.FullName) protoreflect.MessageType {
+	if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
+		return aberrantMessageType{t}
+	}
+	return legacyLoadMessageInfo(t, name)
+}
+
 var legacyMessageTypeCache sync.Map // map[reflect.Type]*MessageInfo
 
 // legacyLoadMessageInfo dynamically loads a *MessageInfo for t,
diff --git a/internal/impl/message.go b/internal/impl/message.go
index c026a98..c8c3859 100644
--- a/internal/impl/message.go
+++ b/internal/impl/message.go
@@ -15,6 +15,7 @@
 	"google.golang.org/protobuf/internal/genid"
 	"google.golang.org/protobuf/reflect/protoreflect"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
+	preg "google.golang.org/protobuf/reflect/protoregistry"
 )
 
 // MessageInfo provides protobuf related functionality for a given Go type
@@ -212,4 +213,53 @@
 func (mi *MessageInfo) Zero() protoreflect.Message {
 	return mi.MessageOf(reflect.Zero(mi.GoReflectType).Interface())
 }
-func (mi *MessageInfo) Descriptor() protoreflect.MessageDescriptor { return mi.Desc }
+func (mi *MessageInfo) Descriptor() protoreflect.MessageDescriptor {
+	return mi.Desc
+}
+func (mi *MessageInfo) Enum(i int) protoreflect.EnumType {
+	mi.init()
+	fd := mi.Desc.Fields().Get(i)
+	return Export{}.EnumTypeOf(mi.fieldTypes[fd.Number()])
+}
+func (mi *MessageInfo) Message(i int) protoreflect.MessageType {
+	mi.init()
+	fd := mi.Desc.Fields().Get(i)
+	switch {
+	case fd.IsWeak():
+		mt, _ := preg.GlobalTypes.FindMessageByName(fd.Message().FullName())
+		return mt
+	case fd.IsMap():
+		return mapEntryType{fd.Message(), mi.fieldTypes[fd.Number()]}
+	default:
+		return Export{}.MessageTypeOf(mi.fieldTypes[fd.Number()])
+	}
+}
+
+type mapEntryType struct {
+	desc    protoreflect.MessageDescriptor
+	valType interface{} // zero value of enum or message type
+}
+
+func (mt mapEntryType) New() protoreflect.Message {
+	return nil
+}
+func (mt mapEntryType) Zero() protoreflect.Message {
+	return nil
+}
+func (mt mapEntryType) Descriptor() protoreflect.MessageDescriptor {
+	return mt.desc
+}
+func (mt mapEntryType) Enum(i int) protoreflect.EnumType {
+	fd := mt.desc.Fields().Get(i)
+	if fd.Enum() == nil {
+		return nil
+	}
+	return Export{}.EnumTypeOf(mt.valType)
+}
+func (mt mapEntryType) Message(i int) protoreflect.MessageType {
+	fd := mt.desc.Fields().Get(i)
+	if fd.Message() == nil {
+		return nil
+	}
+	return Export{}.MessageTypeOf(mt.valType)
+}
diff --git a/internal/impl/message_reflect.go b/internal/impl/message_reflect.go
index f0bb02f..41cd090 100644
--- a/internal/impl/message_reflect.go
+++ b/internal/impl/message_reflect.go
@@ -17,6 +17,11 @@
 	fields map[pref.FieldNumber]*fieldInfo
 	oneofs map[pref.Name]*oneofInfo
 
+	// fieldTypes contains the zero value of an enum or message field.
+	// For lists, it contains the element type.
+	// For maps, it contains the entry value type.
+	fieldTypes map[pref.FieldNumber]interface{}
+
 	// denseFields is a subset of fields where:
 	//	0 < fieldDesc.Number() < len(denseFields)
 	// It provides faster access to the fieldInfo, but may be incomplete.
@@ -37,6 +42,7 @@
 	mi.makeKnownFieldsFunc(si)
 	mi.makeUnknownFieldsFunc(t, si)
 	mi.makeExtensionFieldsFunc(t, si)
+	mi.makeFieldTypes(si)
 }
 
 // makeKnownFieldsFunc generates functions for operations that can be performed
@@ -62,7 +68,7 @@
 			fi = fieldInfoForList(fd, fs, mi.Exporter)
 		case fd.IsWeak():
 			fi = fieldInfoForWeakMessage(fd, si.weakOffset)
-		case fd.Kind() == pref.MessageKind || fd.Kind() == pref.GroupKind:
+		case fd.Message() != nil:
 			fi = fieldInfoForMessage(fd, fs, mi.Exporter)
 		default:
 			fi = fieldInfoForScalar(fd, fs, mi.Exporter)
@@ -146,6 +152,45 @@
 		}
 	}
 }
+func (mi *MessageInfo) makeFieldTypes(si structInfo) {
+	md := mi.Desc
+	fds := md.Fields()
+	for i := 0; i < fds.Len(); i++ {
+		var ft reflect.Type
+		fd := fds.Get(i)
+		fs := si.fieldsByNumber[fd.Number()]
+		switch {
+		case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
+			if fd.Enum() != nil || fd.Message() != nil {
+				ft = si.oneofWrappersByNumber[fd.Number()].Field(0).Type
+			}
+		case fd.IsMap():
+			if fd.MapValue().Enum() != nil || fd.MapValue().Message() != nil {
+				ft = fs.Type.Elem()
+			}
+		case fd.IsList():
+			if fd.Enum() != nil || fd.Message() != nil {
+				ft = fs.Type.Elem()
+			}
+		case fd.Enum() != nil:
+			ft = fs.Type
+			if fd.HasPresence() {
+				ft = ft.Elem()
+			}
+		case fd.Message() != nil:
+			ft = fs.Type
+			if fd.IsWeak() {
+				ft = nil
+			}
+		}
+		if ft != nil {
+			if mi.fieldTypes == nil {
+				mi.fieldTypes = make(map[pref.FieldNumber]interface{})
+			}
+			mi.fieldTypes[fd.Number()] = reflect.Zero(ft).Interface()
+		}
+	}
+}
 
 type extensionMap map[int32]ExtensionField
 
@@ -313,7 +358,6 @@
 // pointer to a named Go struct. If the provided type has a ProtoReflect method,
 // it must be implemented by calling this method.
 func (mi *MessageInfo) MessageOf(m interface{}) pref.Message {
-	// TODO: Switch the input to be an opaque Pointer.
 	if reflect.TypeOf(m) != mi.GoReflectType {
 		panic(fmt.Sprintf("type mismatch: got %T, want %v", m, mi.GoReflectType))
 	}
diff --git a/internal/testprotos/irregular/irregular.go b/internal/testprotos/irregular/irregular.go
index a663e87..48fbd16 100644
--- a/internal/testprotos/irregular/irregular.go
+++ b/internal/testprotos/irregular/irregular.go
@@ -22,10 +22,15 @@
 
 type message IrregularMessage
 
-func (m *message) Descriptor() pref.MessageDescriptor { return fileDesc.Messages().Get(0) }
-func (m *message) Type() pref.MessageType             { return m }
+type messageType struct{}
+
+func (messageType) New() pref.Message                  { return &message{} }
+func (messageType) Zero() pref.Message                 { return (*message)(nil) }
+func (messageType) Descriptor() pref.MessageDescriptor { return fileDesc.Messages().Get(0) }
+
 func (m *message) New() pref.Message                  { return &message{} }
-func (m *message) Zero() pref.Message                 { return (*message)(nil) }
+func (m *message) Descriptor() pref.MessageDescriptor { return fileDesc.Messages().Get(0) }
+func (m *message) Type() pref.MessageType             { return messageType{} }
 func (m *message) Interface() pref.ProtoMessage       { return (*IrregularMessage)(m) }
 func (m *message) ProtoMethods() *protoiface.Methods  { return nil }
 
diff --git a/reflect/protoreflect/type.go b/reflect/protoreflect/type.go
index 58034ef..8e53c44 100644
--- a/reflect/protoreflect/type.go
+++ b/reflect/protoreflect/type.go
@@ -232,11 +232,15 @@
 type isMessageDescriptor interface{ ProtoType(MessageDescriptor) }
 
 // MessageType encapsulates a MessageDescriptor with a concrete Go implementation.
+// It is recommended that implementations of this interface also implement the
+// MessageFieldTypes interface.
 type MessageType interface {
 	// New returns a newly allocated empty message.
+	// It may return nil for synthetic messages representing a map entry.
 	New() Message
 
 	// Zero returns an empty, read-only message.
+	// It may return nil for synthetic messages representing a map entry.
 	Zero() Message
 
 	// Descriptor returns the message descriptor.
@@ -245,6 +249,26 @@
 	Descriptor() MessageDescriptor
 }
 
+// MessageFieldTypes extends a MessageType by providing type information
+// regarding enums and messages referenced by the message fields.
+type MessageFieldTypes interface {
+	MessageType
+
+	// Enum returns the EnumType for the ith field in Descriptor.Fields.
+	// It returns nil if the ith field is not an enum kind.
+	// It panics if out of bounds.
+	//
+	// Invariant: mt.Enum(i).Descriptor() == mt.Descriptor().Fields(i).Enum()
+	Enum(i int) EnumType
+
+	// Message returns the MessageType for the ith field in Descriptor.Fields.
+	// It returns nil if the ith field is not a message or group kind.
+	// It panics if out of bounds.
+	//
+	// Invariant: mt.Message(i).Descriptor() == mt.Descriptor().Fields(i).Message()
+	Message(i int) MessageType
+}
+
 // MessageDescriptors is a list of message declarations.
 type MessageDescriptors interface {
 	// Len reports the number of messages.
diff --git a/testing/prototest/prototest.go b/testing/prototest/message.go
similarity index 86%
rename from testing/prototest/prototest.go
rename to testing/prototest/message.go
index 55a61ae..e495628 100644
--- a/testing/prototest/prototest.go
+++ b/testing/prototest/message.go
@@ -11,11 +11,13 @@
 	"math"
 	"reflect"
 	"sort"
+	"strings"
 	"testing"
 
 	"google.golang.org/protobuf/encoding/prototext"
 	"google.golang.org/protobuf/encoding/protowire"
 	"google.golang.org/protobuf/proto"
+	"google.golang.org/protobuf/reflect/protoreflect"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 	"google.golang.org/protobuf/reflect/protoregistry"
 )
@@ -96,6 +98,112 @@
 	if got := reflect.TypeOf(m.ProtoReflect().Type().Zero().Interface()); got != want {
 		t.Errorf("type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().Type().Zero().Interface()): %v != %v", got, want)
 	}
+	if mt, ok := mt.(pref.MessageFieldTypes); ok {
+		testFieldTypes(t, mt)
+	}
+}
+
+func testFieldTypes(t testing.TB, mt pref.MessageFieldTypes) {
+	descName := func(d pref.Descriptor) pref.FullName {
+		if d == nil {
+			return "<nil>"
+		}
+		return d.FullName()
+	}
+	typeName := func(mt pref.MessageType) pref.FullName {
+		if mt == nil {
+			return "<nil>"
+		}
+		return mt.Descriptor().FullName()
+	}
+	adjustExpr := func(idx int, expr string) string {
+		expr = strings.Replace(expr, "fd.", "md.Fields().Get(i).", -1)
+		expr = strings.Replace(expr, "(fd)", "(md.Fields().Get(i))", -1)
+		expr = strings.Replace(expr, "mti.", "mt.Message(i).", -1)
+		expr = strings.Replace(expr, "(i)", fmt.Sprintf("(%d)", idx), -1)
+		return expr
+	}
+	checkEnumDesc := func(idx int, gotExpr, wantExpr string, got, want protoreflect.EnumDescriptor) {
+		if got != want {
+			t.Errorf("descriptor mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), descName(got), descName(want))
+		}
+	}
+	checkMessageDesc := func(idx int, gotExpr, wantExpr string, got, want protoreflect.MessageDescriptor) {
+		if got != want {
+			t.Errorf("descriptor mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), descName(got), descName(want))
+		}
+	}
+	checkMessageType := func(idx int, gotExpr, wantExpr string, got, want protoreflect.MessageType) {
+		if got != want {
+			t.Errorf("type mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), typeName(got), typeName(want))
+		}
+	}
+
+	fds := mt.Descriptor().Fields()
+	m := mt.New()
+	for i := 0; i < fds.Len(); i++ {
+		fd := fds.Get(i)
+		switch {
+		case fd.IsList():
+			if fd.Enum() != nil {
+				checkEnumDesc(i,
+					"mt.Enum(i).Descriptor()", "fd.Enum()",
+					mt.Enum(i).Descriptor(), fd.Enum())
+			}
+			if fd.Message() != nil {
+				checkMessageDesc(i,
+					"mt.Message(i).Descriptor()", "fd.Message()",
+					mt.Message(i).Descriptor(), fd.Message())
+				checkMessageType(i,
+					"mt.Message(i)", "m.NewField(fd).List().NewElement().Message().Type()",
+					mt.Message(i), m.NewField(fd).List().NewElement().Message().Type())
+			}
+		case fd.IsMap():
+			mti := mt.Message(i)
+			if m := mti.New(); m != nil {
+				checkMessageDesc(i,
+					"m.Descriptor()", "fd.Message()",
+					m.Descriptor(), fd.Message())
+			}
+			if m := mti.Zero(); m != nil {
+				checkMessageDesc(i,
+					"m.Descriptor()", "fd.Message()",
+					m.Descriptor(), fd.Message())
+			}
+			checkMessageDesc(i,
+				"mti.Descriptor()", "fd.Message()",
+				mti.Descriptor(), fd.Message())
+			if mti := mti.(pref.MessageFieldTypes); mti != nil {
+				if fd.MapValue().Enum() != nil {
+					checkEnumDesc(i,
+						"mti.Enum(fd.MapValue().Index()).Descriptor()", "fd.MapValue().Enum()",
+						mti.Enum(fd.MapValue().Index()).Descriptor(), fd.MapValue().Enum())
+				}
+				if fd.MapValue().Message() != nil {
+					checkMessageDesc(i,
+						"mti.Message(fd.MapValue().Index()).Descriptor()", "fd.MapValue().Message()",
+						mti.Message(fd.MapValue().Index()).Descriptor(), fd.MapValue().Message())
+					checkMessageType(i,
+						"mti.Message(fd.MapValue().Index())", "m.NewField(fd).Map().NewValue().Message().Type()",
+						mti.Message(fd.MapValue().Index()), m.NewField(fd).Map().NewValue().Message().Type())
+				}
+			}
+		default:
+			if fd.Enum() != nil {
+				checkEnumDesc(i,
+					"mt.Enum(i).Descriptor()", "fd.Enum()",
+					mt.Enum(i).Descriptor(), fd.Enum())
+			}
+			if fd.Message() != nil {
+				checkMessageDesc(i,
+					"mt.Message(i).Descriptor()", "fd.Message()",
+					mt.Message(i).Descriptor(), fd.Message())
+				checkMessageType(i,
+					"mt.Message(i)", "m.NewField(fd).Message().Type()",
+					mt.Message(i), m.NewField(fd).Message().Type())
+			}
+		}
+	}
 }
 
 // testField exercises set/get/has/clear of a field.
diff --git a/types/dynamicpb/dynamic.go b/types/dynamicpb/dynamic.go
index 7db6e55..900b9d2 100644
--- a/types/dynamicpb/dynamic.go
+++ b/types/dynamicpb/dynamic.go
@@ -369,6 +369,18 @@
 func (mt messageType) New() pref.Message                  { return NewMessage(mt.desc) }
 func (mt messageType) Zero() pref.Message                 { return &Message{typ: messageType{mt.desc}} }
 func (mt messageType) Descriptor() pref.MessageDescriptor { return mt.desc }
+func (mt messageType) Enum(i int) pref.EnumType {
+	if ed := mt.desc.Fields().Get(i).Enum(); ed != nil {
+		return NewEnumType(ed)
+	}
+	return nil
+}
+func (mt messageType) Message(i int) pref.MessageType {
+	if md := mt.desc.Fields().Get(i).Message(); md != nil {
+		return NewMessageType(md)
+	}
+	return nil
+}
 
 type emptyList struct {
 	desc pref.FieldDescriptor