internal/value: expose Converter.{MessageType,EnumType}

Rather than having the Converter carry a NewMessage method, have the struct
simply expose the MessageType or EnumType since they carry more information
and are retrieved anyways as part of the functionality of NewConverter.
While changing Converter, export the fields and remove all the methods.
Also, add an IsLegacy boolean, which is useful for the later implementation
of the extension fields.

Add a wrapLegacyEnum function which is used to wrap v1 enums as v2 enums.
We use this functionality in NewLegacyConverter to detrive the EnumType.
Additionally, modify wrapLegacyMessage to return a protoreflect.ProtoMessage
to be consistent with wrapLegacyEnum which must return a protoreflect.ProtoEnum.

Change-Id: Idc8989d07e4895d30de4ebc22c9ffa7357815cad
Reviewed-on: https://go-review.googlesource.com/c/148827
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/internal/impl/legacy_enum.go b/internal/impl/legacy_enum.go
index 4f83c99..87b4921 100644
--- a/internal/impl/legacy_enum.go
+++ b/internal/impl/legacy_enum.go
@@ -11,10 +11,63 @@
 	"sync"
 
 	descriptorV1 "github.com/golang/protobuf/protoc-gen-go/descriptor"
+	pvalue "github.com/golang/protobuf/v2/internal/value"
 	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
 	ptype "github.com/golang/protobuf/v2/reflect/prototype"
 )
 
+var enumTypeCache sync.Map // map[reflect.Type]protoreflect.EnumType
+
+// wrapLegacyEnum wraps v as a protoreflect.ProtoEnum,
+// where v must be an int32 kind and not implement the v2 API already.
+func wrapLegacyEnum(v reflect.Value) pref.ProtoEnum {
+	// Fast-path: check if a EnumType is cached for this concrete type.
+	if et, ok := enumTypeCache.Load(v.Type()); ok {
+		return et.(pref.EnumType).New(pref.EnumNumber(v.Int()))
+	}
+
+	// Slow-path: derive enum descriptor and initialize EnumType.
+	var m sync.Map // map[protoreflect.EnumNumber]proto.Enum
+	ed := loadEnumDesc(v.Type())
+	et := ptype.GoEnum(ed, func(et pref.EnumType, n pref.EnumNumber) pref.ProtoEnum {
+		if e, ok := m.Load(n); ok {
+			return e.(pref.ProtoEnum)
+		}
+		e := &legacyEnumWrapper{num: n, pbTyp: et, goTyp: v.Type()}
+		m.Store(n, e)
+		return e
+	})
+	enumTypeCache.Store(v.Type(), et)
+	return et.(pref.EnumType).New(pref.EnumNumber(v.Int()))
+}
+
+type legacyEnumWrapper struct {
+	num   pref.EnumNumber
+	pbTyp pref.EnumType
+	goTyp reflect.Type
+}
+
+func (e *legacyEnumWrapper) Number() pref.EnumNumber {
+	return e.num
+}
+func (e *legacyEnumWrapper) Type() pref.EnumType {
+	return e.pbTyp
+}
+func (e *legacyEnumWrapper) ProtoReflect() pref.Enum {
+	return e
+}
+func (e *legacyEnumWrapper) Unwrap() interface{} {
+	v := reflect.New(e.goTyp).Elem()
+	v.SetInt(int64(e.num))
+	return v.Interface()
+}
+
+var (
+	_ pref.Enum        = (*legacyEnumWrapper)(nil)
+	_ pref.ProtoEnum   = (*legacyEnumWrapper)(nil)
+	_ pvalue.Unwrapper = (*legacyEnumWrapper)(nil)
+)
+
 var enumDescCache sync.Map // map[reflect.Type]protoreflect.EnumDescriptor
 
 // loadEnumDesc returns an EnumDescriptor derived from the Go type,
@@ -26,8 +79,8 @@
 	}
 
 	// Slow-path: initialize EnumDescriptor from the proto descriptor.
-	if t.Kind() != reflect.Int32 {
-		panic(fmt.Sprintf("got %v, want int32 kind", t))
+	if t.Kind() != reflect.Int32 || t.PkgPath() == "" {
+		panic(fmt.Sprintf("got %v, want named int32 kind", t))
 	}
 
 	// Derive the enum descriptor from the raw descriptor proto.
diff --git a/internal/impl/legacy_message.go b/internal/impl/legacy_message.go
index 66a71c6..3c43bfa 100644
--- a/internal/impl/legacy_message.go
+++ b/internal/impl/legacy_message.go
@@ -14,26 +14,55 @@
 	protoV1 "github.com/golang/protobuf/proto"
 	descriptorV1 "github.com/golang/protobuf/protoc-gen-go/descriptor"
 	ptag "github.com/golang/protobuf/v2/internal/encoding/tag"
+	pvalue "github.com/golang/protobuf/v2/internal/value"
 	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
 	ptype "github.com/golang/protobuf/v2/reflect/prototype"
 )
 
 var messageTypeCache sync.Map // map[reflect.Type]*MessageType
 
-// wrapLegacyMessage wraps v as a protoreflect.Message, where v must be
-// a *struct kind and not implement the v2 API already.
-func wrapLegacyMessage(v reflect.Value) pref.Message {
+// wrapLegacyMessage wraps v as a protoreflect.ProtoMessage,
+// where v must be a *struct kind and not implement the v2 API already.
+func wrapLegacyMessage(v reflect.Value) pref.ProtoMessage {
 	// Fast-path: check if a MessageType is cached for this concrete type.
 	if mt, ok := messageTypeCache.Load(v.Type()); ok {
-		return mt.(*MessageType).MessageOf(v.Interface())
+		return mt.(*MessageType).MessageOf(v.Interface()).Interface()
 	}
 
 	// Slow-path: derive message descriptor and initialize MessageType.
 	mt := &MessageType{Desc: loadMessageDesc(v.Type())}
 	messageTypeCache.Store(v.Type(), mt)
-	return mt.MessageOf(v.Interface())
+	return mt.MessageOf(v.Interface()).Interface()
 }
 
+type legacyMessageWrapper messageDataType
+
+func (m *legacyMessageWrapper) Type() pref.MessageType {
+	return m.mi.pbType
+}
+func (m *legacyMessageWrapper) KnownFields() pref.KnownFields {
+	return (*knownFields)(m)
+}
+func (m *legacyMessageWrapper) UnknownFields() pref.UnknownFields {
+	return m.mi.unknownFields((*messageDataType)(m))
+}
+func (m *legacyMessageWrapper) Unwrap() interface{} {
+	return m.p.asType(m.mi.goType.Elem()).Interface()
+}
+func (m *legacyMessageWrapper) Interface() pref.ProtoMessage {
+	return m
+}
+func (m *legacyMessageWrapper) ProtoReflect() pref.Message {
+	return m
+}
+func (m *legacyMessageWrapper) ProtoMutable() {}
+
+var (
+	_ pref.Message      = (*legacyMessageWrapper)(nil)
+	_ pref.ProtoMessage = (*legacyMessageWrapper)(nil)
+	_ pvalue.Unwrapper  = (*legacyMessageWrapper)(nil)
+)
+
 var messageDescCache sync.Map // map[reflect.Type]protoreflect.MessageDescriptor
 
 // loadMessageDesc returns an MessageDescriptor derived from the Go type,
@@ -85,8 +114,8 @@
 	}
 
 	// Slow-path: Walk over the struct fields to derive the message descriptor.
-	if t.Kind() != reflect.Ptr && t.Elem().Kind() != reflect.Struct {
-		panic(fmt.Sprintf("got %v, want *struct kind", t))
+	if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct || t.Elem().PkgPath() == "" {
+		panic(fmt.Sprintf("got %v, want named *struct kind", t))
 	}
 
 	// Derive name and syntax from the raw descriptor.
diff --git a/internal/impl/message.go b/internal/impl/message.go
index c6f247e..ab889cc 100644
--- a/internal/impl/message.go
+++ b/internal/impl/message.go
@@ -68,7 +68,7 @@
 		if _, ok := p.(pref.ProtoMessage); !ok {
 			mi.pbType = ptype.GoMessage(mi.Desc, func(pref.MessageType) pref.ProtoMessage {
 				p := reflect.New(t.Elem()).Interface()
-				return (*message)(mi.dataTypeOf(p))
+				return (*legacyMessageWrapper)(mi.dataTypeOf(p))
 			})
 		}
 
@@ -182,7 +182,7 @@
 		// See the comment in MessageType.init regarding pbType.
 		return m.ProtoReflect()
 	}
-	return (*message)(mi.dataTypeOf(p))
+	return (*legacyMessageWrapper)(mi.dataTypeOf(p))
 }
 
 func (mi *MessageType) KnownFieldsOf(p interface{}) pref.KnownFields {
@@ -222,28 +222,6 @@
 	mi *MessageType
 }
 
-type message messageDataType
-
-func (m *message) Type() pref.MessageType {
-	return m.mi.pbType
-}
-func (m *message) KnownFields() pref.KnownFields {
-	return (*knownFields)(m)
-}
-func (m *message) UnknownFields() pref.UnknownFields {
-	return m.mi.unknownFields((*messageDataType)(m))
-}
-func (m *message) Unwrap() interface{} { // TODO: unexport?
-	return m.p.asType(m.mi.goType.Elem()).Interface()
-}
-func (m *message) Interface() pref.ProtoMessage {
-	return m
-}
-func (m *message) ProtoReflect() pref.Message {
-	return m
-}
-func (m *message) ProtoMutable() {}
-
 type knownFields messageDataType
 
 func (fs *knownFields) Len() (cnt int) {
diff --git a/internal/impl/message_field.go b/internal/impl/message_field.go
index 9f015cb..80b369c 100644
--- a/internal/impl/message_field.go
+++ b/internal/impl/message_field.go
@@ -42,7 +42,7 @@
 	if !reflect.PtrTo(ot).Implements(ft) {
 		panic(fmt.Sprintf("invalid type: %v does not implement %v", ot, ft))
 	}
-	conv := value.NewLegacyConverter(ot.Field(0).Type, fd.Kind(), wrapLegacyMessage)
+	conv := value.NewLegacyConverter(ot.Field(0).Type, fd.Kind(), wrapLegacyEnum, wrapLegacyMessage)
 	fieldOffset := offsetOf(fs)
 	// TODO: Implement unsafe fast path?
 	return fieldInfo{
@@ -93,7 +93,7 @@
 			}
 			rv = rv.Elem().Elem().Field(0)
 			if rv.IsNil() {
-				pv := pref.ValueOf(conv.NewMessage())
+				pv := pref.ValueOf(conv.MessageType.New().ProtoReflect())
 				rv.Set(conv.GoValueOf(pv))
 			}
 			return rv.Interface().(pref.Message)
@@ -106,8 +106,8 @@
 	if ft.Kind() != reflect.Map {
 		panic(fmt.Sprintf("invalid type: got %v, want map kind", ft))
 	}
-	keyConv := value.NewLegacyConverter(ft.Key(), fd.MessageType().Fields().ByNumber(1).Kind(), wrapLegacyMessage)
-	valConv := value.NewLegacyConverter(ft.Elem(), fd.MessageType().Fields().ByNumber(2).Kind(), wrapLegacyMessage)
+	keyConv := value.NewLegacyConverter(ft.Key(), fd.MessageType().Fields().ByNumber(1).Kind(), wrapLegacyEnum, wrapLegacyMessage)
+	valConv := value.NewLegacyConverter(ft.Elem(), fd.MessageType().Fields().ByNumber(2).Kind(), wrapLegacyEnum, wrapLegacyMessage)
 	fieldOffset := offsetOf(fs)
 	// TODO: Implement unsafe fast path?
 	return fieldInfo{
@@ -139,7 +139,7 @@
 	if ft.Kind() != reflect.Slice {
 		panic(fmt.Sprintf("invalid type: got %v, want slice kind", ft))
 	}
-	conv := value.NewLegacyConverter(ft.Elem(), fd.Kind(), wrapLegacyMessage)
+	conv := value.NewLegacyConverter(ft.Elem(), fd.Kind(), wrapLegacyEnum, wrapLegacyMessage)
 	fieldOffset := offsetOf(fs)
 	// TODO: Implement unsafe fast path?
 	return fieldInfo{
@@ -179,7 +179,7 @@
 			ft = ft.Elem()
 		}
 	}
-	conv := value.NewLegacyConverter(ft, fd.Kind(), wrapLegacyMessage)
+	conv := value.NewLegacyConverter(ft, fd.Kind(), wrapLegacyEnum, wrapLegacyMessage)
 	fieldOffset := offsetOf(fs)
 	// TODO: Implement unsafe fast path?
 	return fieldInfo{
@@ -244,7 +244,7 @@
 
 func fieldInfoForMessage(fd pref.FieldDescriptor, fs reflect.StructField) fieldInfo {
 	ft := fs.Type
-	conv := value.NewLegacyConverter(ft, fd.Kind(), wrapLegacyMessage)
+	conv := value.NewLegacyConverter(ft, fd.Kind(), wrapLegacyEnum, wrapLegacyMessage)
 	fieldOffset := offsetOf(fs)
 	// TODO: Implement unsafe fast path?
 	return fieldInfo{
@@ -275,7 +275,7 @@
 			// Mutable is only valid for messages and panics for other kinds.
 			rv := p.apply(fieldOffset).asType(fs.Type).Elem()
 			if rv.IsNil() {
-				pv := pref.ValueOf(conv.NewMessage())
+				pv := pref.ValueOf(conv.MessageType.New().ProtoReflect())
 				rv.Set(conv.GoValueOf(pv))
 			}
 			return conv.PBValueOf(rv).Message()
diff --git a/internal/value/convert.go b/internal/value/convert.go
index cd3402a..e4692a7 100644
--- a/internal/value/convert.go
+++ b/internal/value/convert.go
@@ -53,14 +53,17 @@
 // protoc-gen-go historically generated to be able to automatically wrap some
 // v1 messages generated by other forks of protoc-gen-go.
 func NewConverter(t reflect.Type, k pref.Kind) Converter {
-	return NewLegacyConverter(t, k, nil)
+	return NewLegacyConverter(t, k, nil, nil)
 }
 
 // NewLegacyConverter is identical to NewConverter,
 // but supports wrapping legacy v1 messages to implement the v2 message API
-// using the provided wrapLegacyMessage function.
+// using the provided wrapEnum and wrapMessage functions.
 // The wrapped message must implement Unwrapper.
-func NewLegacyConverter(t reflect.Type, k pref.Kind, wrapLegacyMessage func(reflect.Value) pref.Message) Converter {
+func NewLegacyConverter(t reflect.Type, k pref.Kind, wrapEnum func(reflect.Value) pref.ProtoEnum, wrapMessage func(reflect.Value) pref.ProtoMessage) Converter {
+	if (wrapEnum == nil) != (wrapMessage == nil) {
+		panic("legacy enum and message wrappers must both be populated or nil")
+	}
 	switch k {
 	case pref.BoolKind:
 		if t.Kind() == reflect.Bool {
@@ -103,35 +106,39 @@
 		if t.Kind() != reflect.Ptr && t.Implements(enumIfaceV2) {
 			et := reflect.Zero(t).Interface().(pref.ProtoEnum).ProtoReflect().Type()
 			return Converter{
-				toPB: func(v reflect.Value) pref.Value {
+				PBValueOf: func(v reflect.Value) pref.Value {
 					if v.Type() != t {
 						panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), t))
 					}
 					e := v.Interface().(pref.ProtoEnum)
 					return pref.ValueOf(e.ProtoReflect().Number())
 				},
-				toGo: func(v pref.Value) reflect.Value {
+				GoValueOf: func(v pref.Value) reflect.Value {
 					rv := reflect.ValueOf(et.New(v.Enum()))
 					if rv.Type() != t {
 						panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), t))
 					}
 					return rv
 				},
+				EnumType: et,
 			}
 		}
 
 		// Handle v1 enums, which we identify as simply a named int32 type.
-		if wrapLegacyMessage != nil && t.Kind() == reflect.Int32 && t.PkgPath() != "" {
+		if wrapEnum != nil && t.PkgPath() != "" && t.Kind() == reflect.Int32 {
+			et := wrapEnum(reflect.Zero(t)).ProtoReflect().Type()
 			return Converter{
-				toPB: func(v reflect.Value) pref.Value {
+				PBValueOf: func(v reflect.Value) pref.Value {
 					if v.Type() != t {
 						panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), t))
 					}
 					return pref.ValueOf(pref.EnumNumber(v.Int()))
 				},
-				toGo: func(v pref.Value) reflect.Value {
+				GoValueOf: func(v pref.Value) reflect.Value {
 					return reflect.ValueOf(v.Enum()).Convert(t)
 				},
+				EnumType: et,
+				IsLegacy: true,
 			}
 		}
 	case pref.MessageKind, pref.GroupKind:
@@ -139,44 +146,42 @@
 		if t.Kind() == reflect.Ptr && t.Implements(messageIfaceV2) {
 			mt := reflect.Zero(t).Interface().(pref.ProtoMessage).ProtoReflect().Type()
 			return Converter{
-				toPB: func(v reflect.Value) pref.Value {
+				PBValueOf: func(v reflect.Value) pref.Value {
 					if v.Type() != t {
 						panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), t))
 					}
 					return pref.ValueOf(v.Interface())
 				},
-				toGo: func(v pref.Value) reflect.Value {
+				GoValueOf: func(v pref.Value) reflect.Value {
 					rv := reflect.ValueOf(v.Message())
 					if rv.Type() != t {
 						panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), t))
 					}
 					return rv
 				},
-				newMessage: func() pref.Message {
-					return mt.New().ProtoReflect()
-				},
+				MessageType: mt,
 			}
 		}
 
 		// Handle v1 messages, which we need to wrap as a v2 message.
-		if wrapLegacyMessage != nil && t.Kind() == reflect.Ptr && t.Implements(messageIfaceV1) {
+		if wrapMessage != nil && t.Kind() == reflect.Ptr && t.Implements(messageIfaceV1) {
+			mt := wrapMessage(reflect.New(t.Elem())).ProtoReflect().Type()
 			return Converter{
-				toPB: func(v reflect.Value) pref.Value {
+				PBValueOf: func(v reflect.Value) pref.Value {
 					if v.Type() != t {
 						panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), t))
 					}
-					return pref.ValueOf(wrapLegacyMessage(v))
+					return pref.ValueOf(wrapMessage(v).ProtoReflect())
 				},
-				toGo: func(v pref.Value) reflect.Value {
+				GoValueOf: func(v pref.Value) reflect.Value {
 					rv := reflect.ValueOf(v.Message().(Unwrapper).Unwrap())
 					if rv.Type() != t {
 						panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), t))
 					}
 					return rv
 				},
-				newMessage: func() pref.Message {
-					return wrapLegacyMessage(reflect.New(t.Elem()))
-				},
+				MessageType: mt,
+				IsLegacy:    true,
 			}
 		}
 	}
@@ -185,7 +190,7 @@
 
 func makeScalarConverter(goType, pbType reflect.Type) Converter {
 	return Converter{
-		toPB: func(v reflect.Value) pref.Value {
+		PBValueOf: func(v reflect.Value) pref.Value {
 			if v.Type() != goType {
 				panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), goType))
 			}
@@ -194,7 +199,7 @@
 			}
 			return pref.ValueOf(v.Convert(pbType).Interface())
 		},
-		toGo: func(v pref.Value) reflect.Value {
+		GoValueOf: func(v pref.Value) reflect.Value {
 			rv := reflect.ValueOf(v.Interface())
 			if rv.Type() != pbType {
 				panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), pbType))
@@ -210,11 +215,9 @@
 // Converter provides functions for converting to/from Go reflect.Value types
 // and protobuf protoreflect.Value types.
 type Converter struct {
-	toPB       func(reflect.Value) pref.Value
-	toGo       func(pref.Value) reflect.Value
-	newMessage func() pref.Message
+	PBValueOf   func(reflect.Value) pref.Value
+	GoValueOf   func(pref.Value) reflect.Value
+	EnumType    pref.EnumType
+	MessageType pref.MessageType
+	IsLegacy    bool
 }
-
-func (c Converter) PBValueOf(v reflect.Value) pref.Value { return c.toPB(v) }
-func (c Converter) GoValueOf(v pref.Value) reflect.Value { return c.toGo(v) }
-func (c Converter) NewMessage() pref.Message             { return c.newMessage() }
diff --git a/internal/value/map.go b/internal/value/map.go
index c1590d0..c051074 100644
--- a/internal/value/map.go
+++ b/internal/value/map.go
@@ -59,7 +59,7 @@
 	rv := ms.v.MapIndex(rk)
 	if !rv.IsValid() || rv.IsNil() {
 		// TODO: Is checking for nil proper behavior for custom messages?
-		pv := pref.ValueOf(ms.valConv.NewMessage())
+		pv := pref.ValueOf(ms.valConv.MessageType.New().ProtoReflect())
 		rv = ms.valConv.GoValueOf(pv)
 		ms.v.SetMapIndex(rk, rv)
 	}
diff --git a/internal/value/vector.go b/internal/value/vector.go
index 4c1a78d..664ece2 100644
--- a/internal/value/vector.go
+++ b/internal/value/vector.go
@@ -38,14 +38,14 @@
 	rv := vs.v.Index(i)
 	if rv.IsNil() {
 		// TODO: Is checking for nil proper behavior for custom messages?
-		pv := pref.ValueOf(vs.conv.NewMessage())
+		pv := pref.ValueOf(vs.conv.MessageType.New().ProtoReflect())
 		rv.Set(vs.conv.GoValueOf(pv))
 	}
 	return rv.Interface().(pref.Message)
 }
 func (vs vectorReflect) MutableAppend() pref.Mutable {
 	// MutableAppend is only valid for messages and panics for other kinds.
-	pv := pref.ValueOf(vs.conv.NewMessage())
+	pv := pref.ValueOf(vs.conv.MessageType.New().ProtoReflect())
 	vs.v.Set(reflect.Append(vs.v, vs.conv.GoValueOf(pv)))
 	return vs.v.Index(vs.Len() - 1).Interface().(pref.Message)
 }