internal/impl: support message and enum fields

Dynamically generate functions for handling message and enum fields,
regardless of whether they are of the v1 or v2 forms.

If a v1 message is encountered, it is automatically wrapped such that it
implements the v2 interface.

Change-Id: I457bc5286892e8fc00a61da7062dd33058daafd5
Reviewed-on: https://go-review.googlesource.com/c/143837
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/legacy_message.go b/internal/impl/legacy_message.go
index a617249..66a71c6 100644
--- a/internal/impl/legacy_message.go
+++ b/internal/impl/legacy_message.go
@@ -18,10 +18,26 @@
 	ptype "github.com/golang/protobuf/v2/reflect/prototype"
 )
 
+var messageTypeCache sync.Map // map[reflect.Type]*MessageType
+
+// wrapLegacyMessage wraps v as a protoreflect.Message, where v must be
+// a *struct kind and not implement the v2 API already.
+func wrapLegacyMessage(v reflect.Value) pref.Message {
+	// Fast-path: check if a MessageType is cached for this concrete type.
+	if mt, ok := messageTypeCache.Load(v.Type()); ok {
+		return mt.(*MessageType).MessageOf(v.Interface())
+	}
+
+	// Slow-path: derive message descriptor and initialize MessageType.
+	mt := &MessageType{Desc: loadMessageDesc(v.Type())}
+	messageTypeCache.Store(v.Type(), mt)
+	return mt.MessageOf(v.Interface())
+}
+
 var messageDescCache sync.Map // map[reflect.Type]protoreflect.MessageDescriptor
 
 // loadMessageDesc returns an MessageDescriptor derived from the Go type,
-// which must be an *struct kind and not implement the v2 API already.
+// which must be a *struct kind and not implement the v2 API already.
 func loadMessageDesc(t reflect.Type) pref.MessageDescriptor {
 	return messageDescSet{}.Load(t)
 }
diff --git a/internal/impl/message_field.go b/internal/impl/message_field.go
index bdab90a..140f5cd 100644
--- a/internal/impl/message_field.go
+++ b/internal/impl/message_field.go
@@ -60,8 +60,7 @@
 			rv := p.apply(fieldOffset).asType(fs.Type).Elem()
 			if rv.IsNil() || rv.Elem().Type().Elem() != ot {
 				if fd.Kind() == pref.MessageKind || fd.Kind() == pref.GroupKind {
-					// Return a typed nil pointer of the message type to be
-					// consistent with the behavior of generated getters.
+					// TODO: Should this return an invalid protoreflect.Value?
 					rv = reflect.Zero(ot.Field(0).Type)
 					return conv.toPB(rv)
 				}
@@ -198,6 +197,8 @@
 }
 func (ms mapReflect) ProtoMutable() {}
 
+var _ pref.Map = mapReflect{}
+
 func fieldInfoForVector(fd pref.FieldDescriptor, fs reflect.StructField) fieldInfo {
 	ft := fs.Type
 	if ft.Kind() != reflect.Slice {
@@ -349,8 +350,44 @@
 }
 
 func fieldInfoForMessage(fd pref.FieldDescriptor, fs reflect.StructField) fieldInfo {
-	// TODO: support vector fields.
-	panic(fmt.Sprintf("invalid field: %v", fd))
+	ft := fs.Type
+	conv := matchGoTypePBKind(ft, fd.Kind())
+	fieldOffset := offsetOf(fs)
+	// TODO: Implement unsafe fast path?
+	return fieldInfo{
+		has: func(p pointer) bool {
+			rv := p.apply(fieldOffset).asType(fs.Type).Elem()
+			return !rv.IsNil()
+		},
+		get: func(p pointer) pref.Value {
+			// TODO: If rv.IsNil(), should this return a typed-nil pointer or
+			// an invalid protoreflect.Value?
+			//
+			// Returning a typed nil pointer assumes that such values
+			// are valid for all possible custom message types,
+			// which may not be case for dynamic messages.
+			rv := p.apply(fieldOffset).asType(fs.Type).Elem()
+			return conv.toPB(rv)
+		},
+		set: func(p pointer, v pref.Value) {
+			// TODO: Similarly, is it valid to set this to a typed nil pointer?
+			rv := p.apply(fieldOffset).asType(fs.Type).Elem()
+			rv.Set(conv.toGo(v))
+		},
+		clear: func(p pointer) {
+			rv := p.apply(fieldOffset).asType(fs.Type).Elem()
+			rv.Set(reflect.Zero(rv.Type()))
+		},
+		mutable: func(p pointer) pref.Mutable {
+			// Mutable is only valid for messages and panics for other kinds.
+			rv := p.apply(fieldOffset).asType(fs.Type).Elem()
+			if rv.IsNil() {
+				pv := pref.ValueOf(conv.newMessage())
+				rv.Set(conv.toGo(pv))
+			}
+			return conv.toPB(rv).Message()
+		},
+	}
 }
 
 // messageV1 is the protoV1.Message interface.
@@ -424,22 +461,84 @@
 	case pref.EnumKind:
 		// Handle v2 enums, which must satisfy the proto.Enum interface.
 		if t.Kind() != reflect.Ptr && t.Implements(enumIfaceV2) {
-			// TODO: implement this.
+			et := reflect.Zero(t).Interface().(pref.ProtoEnum).ProtoReflect().Type()
+			return converter{
+				toPB: func(v reflect.Value) pref.Value {
+					if v.Type() != t {
+						panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), t))
+					}
+					e := v.Interface().(pref.ProtoEnum)
+					return pref.ValueOf(e.ProtoReflect().Number())
+				},
+				toGo: func(v pref.Value) reflect.Value {
+					rv := reflect.ValueOf(et.GoNew(v.Enum()))
+					if rv.Type() != t {
+						panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), t))
+					}
+					return rv
+				},
+			}
 		}
 
 		// Handle v1 enums, which we identify as simply a named int32 type.
 		if t.Kind() == reflect.Int32 && t.PkgPath() != "" {
-			// TODO: need logic to wrap a legacy enum to implement this.
+			return converter{
+				toPB: func(v reflect.Value) pref.Value {
+					if v.Type() != t {
+						panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), t))
+					}
+					return pref.ValueOf(pref.EnumNumber(v.Int()))
+				},
+				toGo: func(v pref.Value) reflect.Value {
+					return reflect.ValueOf(v.Enum()).Convert(t)
+				},
+			}
 		}
 	case pref.MessageKind, pref.GroupKind:
 		// Handle v2 messages, which must satisfy the proto.Message interface.
 		if t.Kind() == reflect.Ptr && t.Implements(messageIfaceV2) {
-			// TODO: implement this.
+			mt := reflect.Zero(t).Interface().(pref.ProtoMessage).ProtoReflect().Type()
+			return converter{
+				toPB: func(v reflect.Value) pref.Value {
+					if v.Type() != t {
+						panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), t))
+					}
+					return pref.ValueOf(v.Interface())
+				},
+				toGo: func(v pref.Value) reflect.Value {
+					rv := reflect.ValueOf(v.Message())
+					if rv.Type() != t {
+						panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), t))
+					}
+					return rv
+				},
+				newMessage: func() pref.Message {
+					return mt.GoNew().ProtoReflect()
+				},
+			}
 		}
 
 		// Handle v1 messages, which we need to wrap as a v2 message.
 		if t.Kind() == reflect.Ptr && t.Implements(messageIfaceV1) {
-			// TODO: need logic to wrap a legacy message.
+			return converter{
+				toPB: func(v reflect.Value) pref.Value {
+					if v.Type() != t {
+						panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), t))
+					}
+					return pref.ValueOf(wrapLegacyMessage(v))
+				},
+				toGo: func(v pref.Value) reflect.Value {
+					type unwrapper interface{ Unwrap() interface{} }
+					rv := reflect.ValueOf(v.Message().(unwrapper).Unwrap())
+					if rv.Type() != t {
+						panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), t))
+					}
+					return rv
+				},
+				newMessage: func() pref.Message {
+					return wrapLegacyMessage(reflect.New(t.Elem()))
+				},
+			}
 		}
 	}
 	panic(fmt.Sprintf("invalid Go type %v for protobuf kind %v", t, k))
diff --git a/internal/impl/message_test.go b/internal/impl/message_test.go
index ce8bc70..4041f02 100644
--- a/internal/impl/message_test.go
+++ b/internal/impl/message_test.go
@@ -41,13 +41,34 @@
 	MyString  string
 	MyBytes   []byte
 
-	NamedStrings []MyString
-	NamedBytes   []MyBytes
+	VectorStrings []MyString
+	VectorBytes   []MyBytes
 
 	MapStrings map[MyString]MyString
 	MapBytes   map[MyString]MyBytes
+
+	MyEnumV1 pref.EnumNumber
+	MyEnumV2 string
+	myEnumV2 MyEnumV2
+
+	MyMessageV1 struct {
+		// SubMessage *Message
+	}
+	MyMessageV2 map[pref.FieldNumber]pref.Value
+	myMessageV2 MyMessageV2
 )
 
+func (e MyEnumV2) ProtoReflect() pref.Enum { return myEnumV2(e) }
+func (e myEnumV2) Type() pref.EnumType     { return nil } // TODO
+func (e myEnumV2) Number() pref.EnumNumber { return 0 }   // TODO
+
+func (m *MyMessageV2) ProtoReflect() pref.Message        { return (*myMessageV2)(m) }
+func (m *myMessageV2) Type() pref.MessageType            { return nil } // TODO
+func (m *myMessageV2) KnownFields() pref.KnownFields     { return nil } // TODO
+func (m *myMessageV2) UnknownFields() pref.UnknownFields { return nil } // TODO
+func (m *myMessageV2) Interface() pref.ProtoMessage      { return (*MyMessageV2)(m) }
+func (m *myMessageV2) ProtoMutable()                     {}
+
 // List of test operations to perform on messages, vectors, or maps.
 type (
 	messageOp  interface{} // equalMessage | hasFields | getFields | setFields | clearFields | vectorFields | mapFields
@@ -96,33 +117,33 @@
 	// TODO: Mutable
 )
 
+type ScalarProto2 struct {
+	Bool    *bool    `protobuf:"1"`
+	Int32   *int32   `protobuf:"2"`
+	Int64   *int64   `protobuf:"3"`
+	Uint32  *uint32  `protobuf:"4"`
+	Uint64  *uint64  `protobuf:"5"`
+	Float32 *float32 `protobuf:"6"`
+	Float64 *float64 `protobuf:"7"`
+	String  *string  `protobuf:"8"`
+	StringA []byte   `protobuf:"9"`
+	Bytes   []byte   `protobuf:"10"`
+	BytesA  *string  `protobuf:"11"`
+
+	MyBool    *MyBool    `protobuf:"12"`
+	MyInt32   *MyInt32   `protobuf:"13"`
+	MyInt64   *MyInt64   `protobuf:"14"`
+	MyUint32  *MyUint32  `protobuf:"15"`
+	MyUint64  *MyUint64  `protobuf:"16"`
+	MyFloat32 *MyFloat32 `protobuf:"17"`
+	MyFloat64 *MyFloat64 `protobuf:"18"`
+	MyString  *MyString  `protobuf:"19"`
+	MyStringA MyBytes    `protobuf:"20"`
+	MyBytes   MyBytes    `protobuf:"21"`
+	MyBytesA  *MyString  `protobuf:"22"`
+}
+
 func TestScalarProto2(t *testing.T) {
-	type ScalarProto2 struct {
-		Bool    *bool    `protobuf:"1"`
-		Int32   *int32   `protobuf:"2"`
-		Int64   *int64   `protobuf:"3"`
-		Uint32  *uint32  `protobuf:"4"`
-		Uint64  *uint64  `protobuf:"5"`
-		Float32 *float32 `protobuf:"6"`
-		Float64 *float64 `protobuf:"7"`
-		String  *string  `protobuf:"8"`
-		StringA []byte   `protobuf:"9"`
-		Bytes   []byte   `protobuf:"10"`
-		BytesA  *string  `protobuf:"11"`
-
-		MyBool    *MyBool    `protobuf:"12"`
-		MyInt32   *MyInt32   `protobuf:"13"`
-		MyInt64   *MyInt64   `protobuf:"14"`
-		MyUint32  *MyUint32  `protobuf:"15"`
-		MyUint64  *MyUint64  `protobuf:"16"`
-		MyFloat32 *MyFloat32 `protobuf:"17"`
-		MyFloat64 *MyFloat64 `protobuf:"18"`
-		MyString  *MyString  `protobuf:"19"`
-		MyStringA MyBytes    `protobuf:"20"`
-		MyBytes   MyBytes    `protobuf:"21"`
-		MyBytesA  *MyString  `protobuf:"22"`
-	}
-
 	mi := MessageType{Desc: mustMakeMessageDesc(ptype.StandaloneMessage{
 		Syntax:   pref.Proto2,
 		FullName: "ScalarProto2",
@@ -182,33 +203,33 @@
 	})
 }
 
+type ScalarProto3 struct {
+	Bool    bool    `protobuf:"1"`
+	Int32   int32   `protobuf:"2"`
+	Int64   int64   `protobuf:"3"`
+	Uint32  uint32  `protobuf:"4"`
+	Uint64  uint64  `protobuf:"5"`
+	Float32 float32 `protobuf:"6"`
+	Float64 float64 `protobuf:"7"`
+	String  string  `protobuf:"8"`
+	StringA []byte  `protobuf:"9"`
+	Bytes   []byte  `protobuf:"10"`
+	BytesA  string  `protobuf:"11"`
+
+	MyBool    MyBool    `protobuf:"12"`
+	MyInt32   MyInt32   `protobuf:"13"`
+	MyInt64   MyInt64   `protobuf:"14"`
+	MyUint32  MyUint32  `protobuf:"15"`
+	MyUint64  MyUint64  `protobuf:"16"`
+	MyFloat32 MyFloat32 `protobuf:"17"`
+	MyFloat64 MyFloat64 `protobuf:"18"`
+	MyString  MyString  `protobuf:"19"`
+	MyStringA MyBytes   `protobuf:"20"`
+	MyBytes   MyBytes   `protobuf:"21"`
+	MyBytesA  MyString  `protobuf:"22"`
+}
+
 func TestScalarProto3(t *testing.T) {
-	type ScalarProto3 struct {
-		Bool    bool    `protobuf:"1"`
-		Int32   int32   `protobuf:"2"`
-		Int64   int64   `protobuf:"3"`
-		Uint32  uint32  `protobuf:"4"`
-		Uint64  uint64  `protobuf:"5"`
-		Float32 float32 `protobuf:"6"`
-		Float64 float64 `protobuf:"7"`
-		String  string  `protobuf:"8"`
-		StringA []byte  `protobuf:"9"`
-		Bytes   []byte  `protobuf:"10"`
-		BytesA  string  `protobuf:"11"`
-
-		MyBool    MyBool    `protobuf:"12"`
-		MyInt32   MyInt32   `protobuf:"13"`
-		MyInt64   MyInt64   `protobuf:"14"`
-		MyUint32  MyUint32  `protobuf:"15"`
-		MyUint64  MyUint64  `protobuf:"16"`
-		MyFloat32 MyFloat32 `protobuf:"17"`
-		MyFloat64 MyFloat64 `protobuf:"18"`
-		MyString  MyString  `protobuf:"19"`
-		MyStringA MyBytes   `protobuf:"20"`
-		MyBytes   MyBytes   `protobuf:"21"`
-		MyBytesA  MyString  `protobuf:"22"`
-	}
-
 	mi := MessageType{Desc: mustMakeMessageDesc(ptype.StandaloneMessage{
 		Syntax:   pref.Proto3,
 		FullName: "ScalarProto3",
@@ -277,31 +298,31 @@
 	})
 }
 
+type RepeatedScalars struct {
+	Bools    []bool    `protobuf:"1"`
+	Int32s   []int32   `protobuf:"2"`
+	Int64s   []int64   `protobuf:"3"`
+	Uint32s  []uint32  `protobuf:"4"`
+	Uint64s  []uint64  `protobuf:"5"`
+	Float32s []float32 `protobuf:"6"`
+	Float64s []float64 `protobuf:"7"`
+	Strings  []string  `protobuf:"8"`
+	StringsA [][]byte  `protobuf:"9"`
+	Bytes    [][]byte  `protobuf:"10"`
+	BytesA   []string  `protobuf:"11"`
+
+	MyStrings1 []MyString `protobuf:"12"`
+	MyStrings2 []MyBytes  `protobuf:"13"`
+	MyBytes1   []MyBytes  `protobuf:"14"`
+	MyBytes2   []MyString `protobuf:"15"`
+
+	MyStrings3 VectorStrings `protobuf:"16"`
+	MyStrings4 VectorBytes   `protobuf:"17"`
+	MyBytes3   VectorBytes   `protobuf:"18"`
+	MyBytes4   VectorStrings `protobuf:"19"`
+}
+
 func TestRepeatedScalars(t *testing.T) {
-	type RepeatedScalars struct {
-		Bools    []bool    `protobuf:"1"`
-		Int32s   []int32   `protobuf:"2"`
-		Int64s   []int64   `protobuf:"3"`
-		Uint32s  []uint32  `protobuf:"4"`
-		Uint64s  []uint64  `protobuf:"5"`
-		Float32s []float32 `protobuf:"6"`
-		Float64s []float64 `protobuf:"7"`
-		Strings  []string  `protobuf:"8"`
-		StringsA [][]byte  `protobuf:"9"`
-		Bytes    [][]byte  `protobuf:"10"`
-		BytesA   []string  `protobuf:"11"`
-
-		MyStrings1 []MyString `protobuf:"12"`
-		MyStrings2 []MyBytes  `protobuf:"13"`
-		MyBytes1   []MyBytes  `protobuf:"14"`
-		MyBytes2   []MyString `protobuf:"15"`
-
-		MyStrings3 NamedStrings `protobuf:"16"`
-		MyStrings4 NamedBytes   `protobuf:"17"`
-		MyBytes3   NamedBytes   `protobuf:"18"`
-		MyBytes4   NamedStrings `protobuf:"19"`
-	}
-
 	mi := MessageType{Desc: mustMakeMessageDesc(ptype.StandaloneMessage{
 		Syntax:   pref.Proto2,
 		FullName: "RepeatedScalars",
@@ -351,10 +372,10 @@
 		MyBytes1:   []MyBytes{[]byte("14"), nil, []byte("fourteen")},
 		MyBytes2:   []MyString{"15", "", "fifteen"},
 
-		MyStrings3: NamedStrings{"16", "", "sixteen"},
-		MyStrings4: NamedBytes{[]byte("17"), nil, []byte("seventeen")},
-		MyBytes3:   NamedBytes{[]byte("18"), nil, []byte("eighteen")},
-		MyBytes4:   NamedStrings{"19", "", "nineteen"},
+		MyStrings3: VectorStrings{"16", "", "sixteen"},
+		MyStrings4: VectorBytes{[]byte("17"), nil, []byte("seventeen")},
+		MyBytes3:   VectorBytes{[]byte("18"), nil, []byte("eighteen")},
+		MyBytes4:   VectorStrings{"19", "", "nineteen"},
 	})
 	wantFS := want.KnownFields()
 
@@ -418,38 +439,38 @@
 	})
 }
 
+type MapScalars struct {
+	KeyBools   map[bool]string   `protobuf:"1"`
+	KeyInt32s  map[int32]string  `protobuf:"2"`
+	KeyInt64s  map[int64]string  `protobuf:"3"`
+	KeyUint32s map[uint32]string `protobuf:"4"`
+	KeyUint64s map[uint64]string `protobuf:"5"`
+	KeyStrings map[string]string `protobuf:"6"`
+
+	ValBools    map[string]bool    `protobuf:"7"`
+	ValInt32s   map[string]int32   `protobuf:"8"`
+	ValInt64s   map[string]int64   `protobuf:"9"`
+	ValUint32s  map[string]uint32  `protobuf:"10"`
+	ValUint64s  map[string]uint64  `protobuf:"11"`
+	ValFloat32s map[string]float32 `protobuf:"12"`
+	ValFloat64s map[string]float64 `protobuf:"13"`
+	ValStrings  map[string]string  `protobuf:"14"`
+	ValStringsA map[string][]byte  `protobuf:"15"`
+	ValBytes    map[string][]byte  `protobuf:"16"`
+	ValBytesA   map[string]string  `protobuf:"17"`
+
+	MyStrings1 map[MyString]MyString `protobuf:"18"`
+	MyStrings2 map[MyString]MyBytes  `protobuf:"19"`
+	MyBytes1   map[MyString]MyBytes  `protobuf:"20"`
+	MyBytes2   map[MyString]MyString `protobuf:"21"`
+
+	MyStrings3 MapStrings `protobuf:"22"`
+	MyStrings4 MapBytes   `protobuf:"23"`
+	MyBytes3   MapBytes   `protobuf:"24"`
+	MyBytes4   MapStrings `protobuf:"25"`
+}
+
 func TestMapScalars(t *testing.T) {
-	type MapScalars struct {
-		KeyBools   map[bool]string   `protobuf:"1"`
-		KeyInt32s  map[int32]string  `protobuf:"2"`
-		KeyInt64s  map[int64]string  `protobuf:"3"`
-		KeyUint32s map[uint32]string `protobuf:"4"`
-		KeyUint64s map[uint64]string `protobuf:"5"`
-		KeyStrings map[string]string `protobuf:"6"`
-
-		ValBools    map[string]bool    `protobuf:"7"`
-		ValInt32s   map[string]int32   `protobuf:"8"`
-		ValInt64s   map[string]int64   `protobuf:"9"`
-		ValUint32s  map[string]uint32  `protobuf:"10"`
-		ValUint64s  map[string]uint64  `protobuf:"11"`
-		ValFloat32s map[string]float32 `protobuf:"12"`
-		ValFloat64s map[string]float64 `protobuf:"13"`
-		ValStrings  map[string]string  `protobuf:"14"`
-		ValStringsA map[string][]byte  `protobuf:"15"`
-		ValBytes    map[string][]byte  `protobuf:"16"`
-		ValBytesA   map[string]string  `protobuf:"17"`
-
-		MyStrings1 map[MyString]MyString `protobuf:"18"`
-		MyStrings2 map[MyString]MyBytes  `protobuf:"19"`
-		MyBytes1   map[MyString]MyBytes  `protobuf:"20"`
-		MyBytes2   map[MyString]MyString `protobuf:"21"`
-
-		MyStrings3 MapStrings `protobuf:"22"`
-		MyStrings4 MapBytes   `protobuf:"23"`
-		MyBytes3   MapBytes   `protobuf:"24"`
-		MyBytes4   MapStrings `protobuf:"25"`
-	}
-
 	mustMakeMapEntry := func(n pref.FieldNumber, keyKind, valKind pref.Kind) ptype.Field {
 		return ptype.Field{
 			Name:        pref.Name(fmt.Sprintf("f%d", n)),