reflect/protoreflect: add ExtensionType IsValid{Interface,Value} methods

Add a way to typecheck a Value or interface{} without converting it to
the other form.  This permits implementations which store field values as
a Value (such as dynamicpb, or (soon) extensions in generated messages)
to validate inputs without an unnecessary conversion.

Fixes golang/protobuf#905

Change-Id: I1b78612b22ae832efbb55f81ae420871729e3a02
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/192457
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/convert.go b/internal/impl/convert.go
index 9b74c03..2186b26 100644
--- a/internal/impl/convert.go
+++ b/internal/impl/convert.go
@@ -25,6 +25,12 @@
 	// GoValueOf converts a protoreflect.Value to a reflect.Value.
 	GoValueOf(pref.Value) reflect.Value
 
+	// IsValidPB returns whether a protoreflect.Value is compatible with this type.
+	IsValidPB(pref.Value) bool
+
+	// IsValidGo returns whether a reflect.Value is compatible with this type.
+	IsValidGo(reflect.Value) bool
+
 	// New returns a new field value.
 	// For scalars, it returns the default value of the field.
 	// For composite types, it returns a new mutable value.
@@ -151,6 +157,13 @@
 func (c *boolConverter) GoValueOf(v pref.Value) reflect.Value {
 	return reflect.ValueOf(v.Bool()).Convert(c.goType)
 }
+func (c *boolConverter) IsValidPB(v pref.Value) bool {
+	_, ok := v.Interface().(bool)
+	return ok
+}
+func (c *boolConverter) IsValidGo(v reflect.Value) bool {
+	return v.Type() == c.goType
+}
 func (c *boolConverter) New() pref.Value  { return c.def }
 func (c *boolConverter) Zero() pref.Value { return c.def }
 
@@ -168,6 +181,13 @@
 func (c *int32Converter) GoValueOf(v pref.Value) reflect.Value {
 	return reflect.ValueOf(int32(v.Int())).Convert(c.goType)
 }
+func (c *int32Converter) IsValidPB(v pref.Value) bool {
+	_, ok := v.Interface().(int32)
+	return ok
+}
+func (c *int32Converter) IsValidGo(v reflect.Value) bool {
+	return v.Type() == c.goType
+}
 func (c *int32Converter) New() pref.Value  { return c.def }
 func (c *int32Converter) Zero() pref.Value { return c.def }
 
@@ -185,6 +205,13 @@
 func (c *int64Converter) GoValueOf(v pref.Value) reflect.Value {
 	return reflect.ValueOf(int64(v.Int())).Convert(c.goType)
 }
+func (c *int64Converter) IsValidPB(v pref.Value) bool {
+	_, ok := v.Interface().(int64)
+	return ok
+}
+func (c *int64Converter) IsValidGo(v reflect.Value) bool {
+	return v.Type() == c.goType
+}
 func (c *int64Converter) New() pref.Value  { return c.def }
 func (c *int64Converter) Zero() pref.Value { return c.def }
 
@@ -202,6 +229,13 @@
 func (c *uint32Converter) GoValueOf(v pref.Value) reflect.Value {
 	return reflect.ValueOf(uint32(v.Uint())).Convert(c.goType)
 }
+func (c *uint32Converter) IsValidPB(v pref.Value) bool {
+	_, ok := v.Interface().(uint32)
+	return ok
+}
+func (c *uint32Converter) IsValidGo(v reflect.Value) bool {
+	return v.Type() == c.goType
+}
 func (c *uint32Converter) New() pref.Value  { return c.def }
 func (c *uint32Converter) Zero() pref.Value { return c.def }
 
@@ -219,6 +253,13 @@
 func (c *uint64Converter) GoValueOf(v pref.Value) reflect.Value {
 	return reflect.ValueOf(uint64(v.Uint())).Convert(c.goType)
 }
+func (c *uint64Converter) IsValidPB(v pref.Value) bool {
+	_, ok := v.Interface().(uint64)
+	return ok
+}
+func (c *uint64Converter) IsValidGo(v reflect.Value) bool {
+	return v.Type() == c.goType
+}
 func (c *uint64Converter) New() pref.Value  { return c.def }
 func (c *uint64Converter) Zero() pref.Value { return c.def }
 
@@ -236,6 +277,13 @@
 func (c *float32Converter) GoValueOf(v pref.Value) reflect.Value {
 	return reflect.ValueOf(float32(v.Float())).Convert(c.goType)
 }
+func (c *float32Converter) IsValidPB(v pref.Value) bool {
+	_, ok := v.Interface().(float32)
+	return ok
+}
+func (c *float32Converter) IsValidGo(v reflect.Value) bool {
+	return v.Type() == c.goType
+}
 func (c *float32Converter) New() pref.Value  { return c.def }
 func (c *float32Converter) Zero() pref.Value { return c.def }
 
@@ -253,6 +301,13 @@
 func (c *float64Converter) GoValueOf(v pref.Value) reflect.Value {
 	return reflect.ValueOf(float64(v.Float())).Convert(c.goType)
 }
+func (c *float64Converter) IsValidPB(v pref.Value) bool {
+	_, ok := v.Interface().(float64)
+	return ok
+}
+func (c *float64Converter) IsValidGo(v reflect.Value) bool {
+	return v.Type() == c.goType
+}
 func (c *float64Converter) New() pref.Value  { return c.def }
 func (c *float64Converter) Zero() pref.Value { return c.def }
 
@@ -276,6 +331,13 @@
 	}
 	return reflect.ValueOf(s).Convert(c.goType)
 }
+func (c *stringConverter) IsValidPB(v pref.Value) bool {
+	_, ok := v.Interface().(string)
+	return ok
+}
+func (c *stringConverter) IsValidGo(v reflect.Value) bool {
+	return v.Type() == c.goType
+}
 func (c *stringConverter) New() pref.Value  { return c.def }
 func (c *stringConverter) Zero() pref.Value { return c.def }
 
@@ -296,6 +358,13 @@
 func (c *bytesConverter) GoValueOf(v pref.Value) reflect.Value {
 	return reflect.ValueOf(v.Bytes()).Convert(c.goType)
 }
+func (c *bytesConverter) IsValidPB(v pref.Value) bool {
+	_, ok := v.Interface().([]byte)
+	return ok
+}
+func (c *bytesConverter) IsValidGo(v reflect.Value) bool {
+	return v.Type() == c.goType
+}
 func (c *bytesConverter) New() pref.Value  { return c.def }
 func (c *bytesConverter) Zero() pref.Value { return c.def }
 
@@ -325,6 +394,15 @@
 	return reflect.ValueOf(v.Enum()).Convert(c.goType)
 }
 
+func (c *enumConverter) IsValidPB(v pref.Value) bool {
+	_, ok := v.Interface().(pref.EnumNumber)
+	return ok
+}
+
+func (c *enumConverter) IsValidGo(v reflect.Value) bool {
+	return v.Type() == c.goType
+}
+
 func (c *enumConverter) New() pref.Value {
 	return c.def
 }
@@ -365,6 +443,21 @@
 	return rv
 }
 
+func (c *messageConverter) IsValidPB(v pref.Value) bool {
+	m := v.Message()
+	var rv reflect.Value
+	if u, ok := m.(Unwrapper); ok {
+		rv = reflect.ValueOf(u.ProtoUnwrap())
+	} else {
+		rv = reflect.ValueOf(m.Interface())
+	}
+	return rv.Type() == c.goType
+}
+
+func (c *messageConverter) IsValidGo(v reflect.Value) bool {
+	return v.Type() == c.goType
+}
+
 func (c *messageConverter) New() pref.Value {
 	return c.PBValueOf(reflect.New(c.goType.Elem()))
 }
diff --git a/internal/impl/convert_list.go b/internal/impl/convert_list.go
index f9001b5..19748b4 100644
--- a/internal/impl/convert_list.go
+++ b/internal/impl/convert_list.go
@@ -34,6 +34,18 @@
 	return v.List().(*listReflect).v
 }
 
+func (c *listConverter) IsValidPB(v pref.Value) bool {
+	list, ok := v.Interface().(*listReflect)
+	if !ok {
+		return false
+	}
+	return list.v.Type() == c.goType
+}
+
+func (c *listConverter) IsValidGo(v reflect.Value) bool {
+	return v.Type() == c.goType
+}
+
 func (c *listConverter) New() pref.Value {
 	return c.PBValueOf(reflect.New(c.goType.Elem()))
 }
diff --git a/internal/impl/convert_map.go b/internal/impl/convert_map.go
index 4182cbe..447a965 100644
--- a/internal/impl/convert_map.go
+++ b/internal/impl/convert_map.go
@@ -38,6 +38,18 @@
 	return v.Map().(*mapReflect).v
 }
 
+func (c *mapConverter) IsValidPB(v pref.Value) bool {
+	mapv, ok := v.Interface().(*mapReflect)
+	if !ok {
+		return false
+	}
+	return mapv.v.Type() == c.goType
+}
+
+func (c *mapConverter) IsValidGo(v reflect.Value) bool {
+	return v.Type() == c.goType
+}
+
 func (c *mapConverter) New() pref.Value {
 	return c.PBValueOf(reflect.MakeMap(c.goType))
 }
diff --git a/internal/impl/extension.go b/internal/impl/extension.go
index eda6ab9..8192223 100644
--- a/internal/impl/extension.go
+++ b/internal/impl/extension.go
@@ -114,6 +114,12 @@
 func (xi *ExtensionInfo) InterfaceOf(v pref.Value) interface{} {
 	return xi.lazyInit().GoValueOf(v).Interface()
 }
+func (xi *ExtensionInfo) IsValidValue(v pref.Value) bool {
+	return xi.lazyInit().IsValidPB(v)
+}
+func (xi *ExtensionInfo) IsValidInterface(v interface{}) bool {
+	return xi.lazyInit().IsValidGo(reflect.ValueOf(v))
+}
 func (xi *ExtensionInfo) GoType() reflect.Type {
 	xi.lazyInit()
 	return xi.goType
diff --git a/internal/impl/extension_test.go b/internal/impl/extension_test.go
new file mode 100644
index 0000000..d6353ed
--- /dev/null
+++ b/internal/impl/extension_test.go
@@ -0,0 +1,130 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package impl_test
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/golang/protobuf/proto"
+	cmp "github.com/google/go-cmp/cmp"
+	testpb "google.golang.org/protobuf/internal/testprotos/test"
+	pref "google.golang.org/protobuf/reflect/protoreflect"
+)
+
+func TestExtensionType(t *testing.T) {
+	cmpOpts := cmp.Options{
+		cmp.Comparer(func(x, y proto.Message) bool {
+			return proto.Equal(x, y)
+		}),
+	}
+	for _, test := range []struct {
+		xt    pref.ExtensionType
+		value interface{}
+	}{
+		{
+			xt:    testpb.E_OptionalInt32Extension,
+			value: int32(0),
+		},
+		{
+			xt:    testpb.E_OptionalInt64Extension,
+			value: int64(0),
+		},
+		{
+			xt:    testpb.E_OptionalUint32Extension,
+			value: uint32(0),
+		},
+		{
+			xt:    testpb.E_OptionalUint64Extension,
+			value: uint64(0),
+		},
+		{
+			xt:    testpb.E_OptionalFloatExtension,
+			value: float32(0),
+		},
+		{
+			xt:    testpb.E_OptionalDoubleExtension,
+			value: float64(0),
+		},
+		{
+			xt:    testpb.E_OptionalBoolExtension,
+			value: true,
+		},
+		{
+			xt:    testpb.E_OptionalStringExtension,
+			value: "",
+		},
+		{
+			xt:    testpb.E_OptionalBytesExtension,
+			value: []byte{},
+		},
+		{
+			xt:    testpb.E_OptionalNestedMessageExtension,
+			value: &testpb.TestAllTypes_NestedMessage{},
+		},
+		{
+			xt:    testpb.E_OptionalNestedEnumExtension,
+			value: testpb.TestAllTypes_FOO,
+		},
+		{
+			xt:    testpb.E_RepeatedInt32Extension,
+			value: []int32{0},
+		},
+		{
+			xt:    testpb.E_RepeatedInt64Extension,
+			value: []int64{0},
+		},
+		{
+			xt:    testpb.E_RepeatedUint32Extension,
+			value: []uint32{0},
+		},
+		{
+			xt:    testpb.E_RepeatedUint64Extension,
+			value: []uint64{0},
+		},
+		{
+			xt:    testpb.E_RepeatedFloatExtension,
+			value: []float32{0},
+		},
+		{
+			xt:    testpb.E_RepeatedDoubleExtension,
+			value: []float64{0},
+		},
+		{
+			xt:    testpb.E_RepeatedBoolExtension,
+			value: []bool{true},
+		},
+		{
+			xt:    testpb.E_RepeatedStringExtension,
+			value: []string{""},
+		},
+		{
+			xt:    testpb.E_RepeatedBytesExtension,
+			value: [][]byte{nil},
+		},
+		{
+			xt:    testpb.E_RepeatedNestedMessageExtension,
+			value: []*testpb.TestAllTypes_NestedMessage{{}},
+		},
+		{
+			xt:    testpb.E_RepeatedNestedEnumExtension,
+			value: []testpb.TestAllTypes_NestedEnum{testpb.TestAllTypes_FOO},
+		},
+	} {
+		name := test.xt.TypeDescriptor().FullName()
+		t.Run(fmt.Sprint(name), func(t *testing.T) {
+			if !test.xt.IsValidInterface(test.value) {
+				t.Fatalf("IsValidInterface(%[1]T(%[1]v)) = false, want true", test.value)
+			}
+			v := test.xt.ValueOf(test.value)
+			if !test.xt.IsValidValue(v) {
+				t.Fatalf("IsValidValue(%[1]T(%[1]v)) = false, want true", v)
+			}
+			if got, want := test.xt.InterfaceOf(v), test.value; !cmp.Equal(got, want, cmpOpts) {
+				t.Fatalf("round trip InterfaceOf(ValueOf(x)) = %v, want %v", got, want)
+			}
+		})
+	}
+}
diff --git a/reflect/protoreflect/type.go b/reflect/protoreflect/type.go
index e396862..223599a 100644
--- a/reflect/protoreflect/type.go
+++ b/reflect/protoreflect/type.go
@@ -484,6 +484,12 @@
 	// InterfaceOf is able to unwrap the Value further than Value.Interface
 	// as it has more type information available.
 	InterfaceOf(Value) interface{}
+
+	// IsValidValue returns whether the Value is valid to assign to the field.
+	IsValidValue(Value) bool
+
+	// IsValidInterface returns whether the input is valid to assign to the field.
+	IsValidInterface(interface{}) bool
 }
 
 // EnumDescriptor describes an enum and
diff --git a/types/dynamicpb/dynamic.go b/types/dynamicpb/dynamic.go
index 64efbe1..7b8c8d0 100644
--- a/types/dynamicpb/dynamic.go
+++ b/types/dynamicpb/dynamic.go
@@ -166,8 +166,9 @@
 	m.checkField(fd)
 	switch {
 	case fd.IsExtension():
-		// Call InterfaceOf just to let the extension typecheck the value.
-		_ = fd.(pref.ExtensionTypeDescriptor).Type().InterfaceOf(v)
+		if !fd.(pref.ExtensionTypeDescriptor).Type().IsValidValue(v) {
+			panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
+		}
 		m.ext[fd.Number()] = fd
 	case fd.IsMap():
 		if mapv, ok := v.Interface().(*dynamicMap); !ok || mapv.desc != fd {