all: improve extension validation

Changes made:
* Ensure protoreflect.ExtensionType.IsValidInterface never panics,
especially if given a nil interface value.
* Have protoreflect.ExtensionType.IsValid{Interface,Value} only
perform type-checks. It does not do value checks (i.e., whether the
value itself is valid). Value validity is left to when an actual
protoreflect.Message.Set operation is performed.
* Add special-casing on proto.SetExtension to treat an invalid
message or list as functionally equivalent to Clear. This is to
be more consistent with the legacy SetExtension implementation
which never panicked when given such values.
* Add special-casing on proto.HasExtension to treat a mismatched
extension descriptor as simply not being present in the message.
This is also to be more consistent with the legacy HasExtension
implementation which did the same thing.

Change-Id: Idf0419abf27b9f85d9b92bd2ff8088e25b7990cc
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/229558
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/convert.go b/internal/impl/convert.go
index 9fc384a..36a90df 100644
--- a/internal/impl/convert.go
+++ b/internal/impl/convert.go
@@ -162,7 +162,7 @@
 	return ok
 }
 func (c *boolConverter) IsValidGo(v reflect.Value) bool {
-	return v.Type() == c.goType
+	return v.IsValid() && v.Type() == c.goType
 }
 func (c *boolConverter) New() pref.Value  { return c.def }
 func (c *boolConverter) Zero() pref.Value { return c.def }
@@ -186,7 +186,7 @@
 	return ok
 }
 func (c *int32Converter) IsValidGo(v reflect.Value) bool {
-	return v.Type() == c.goType
+	return v.IsValid() && v.Type() == c.goType
 }
 func (c *int32Converter) New() pref.Value  { return c.def }
 func (c *int32Converter) Zero() pref.Value { return c.def }
@@ -210,7 +210,7 @@
 	return ok
 }
 func (c *int64Converter) IsValidGo(v reflect.Value) bool {
-	return v.Type() == c.goType
+	return v.IsValid() && v.Type() == c.goType
 }
 func (c *int64Converter) New() pref.Value  { return c.def }
 func (c *int64Converter) Zero() pref.Value { return c.def }
@@ -234,7 +234,7 @@
 	return ok
 }
 func (c *uint32Converter) IsValidGo(v reflect.Value) bool {
-	return v.Type() == c.goType
+	return v.IsValid() && v.Type() == c.goType
 }
 func (c *uint32Converter) New() pref.Value  { return c.def }
 func (c *uint32Converter) Zero() pref.Value { return c.def }
@@ -258,7 +258,7 @@
 	return ok
 }
 func (c *uint64Converter) IsValidGo(v reflect.Value) bool {
-	return v.Type() == c.goType
+	return v.IsValid() && v.Type() == c.goType
 }
 func (c *uint64Converter) New() pref.Value  { return c.def }
 func (c *uint64Converter) Zero() pref.Value { return c.def }
@@ -282,7 +282,7 @@
 	return ok
 }
 func (c *float32Converter) IsValidGo(v reflect.Value) bool {
-	return v.Type() == c.goType
+	return v.IsValid() && v.Type() == c.goType
 }
 func (c *float32Converter) New() pref.Value  { return c.def }
 func (c *float32Converter) Zero() pref.Value { return c.def }
@@ -306,7 +306,7 @@
 	return ok
 }
 func (c *float64Converter) IsValidGo(v reflect.Value) bool {
-	return v.Type() == c.goType
+	return v.IsValid() && v.Type() == c.goType
 }
 func (c *float64Converter) New() pref.Value  { return c.def }
 func (c *float64Converter) Zero() pref.Value { return c.def }
@@ -336,7 +336,7 @@
 	return ok
 }
 func (c *stringConverter) IsValidGo(v reflect.Value) bool {
-	return v.Type() == c.goType
+	return v.IsValid() && v.Type() == c.goType
 }
 func (c *stringConverter) New() pref.Value  { return c.def }
 func (c *stringConverter) Zero() pref.Value { return c.def }
@@ -363,7 +363,7 @@
 	return ok
 }
 func (c *bytesConverter) IsValidGo(v reflect.Value) bool {
-	return v.Type() == c.goType
+	return v.IsValid() && v.Type() == c.goType
 }
 func (c *bytesConverter) New() pref.Value  { return c.def }
 func (c *bytesConverter) Zero() pref.Value { return c.def }
@@ -400,7 +400,7 @@
 }
 
 func (c *enumConverter) IsValidGo(v reflect.Value) bool {
-	return v.Type() == c.goType
+	return v.IsValid() && v.Type() == c.goType
 }
 
 func (c *enumConverter) New() pref.Value {
@@ -455,7 +455,7 @@
 }
 
 func (c *messageConverter) IsValidGo(v reflect.Value) bool {
-	return v.Type() == c.goType
+	return v.IsValid() && v.Type() == c.goType
 }
 
 func (c *messageConverter) New() pref.Value {
diff --git a/internal/impl/convert_list.go b/internal/impl/convert_list.go
index fe9384a..6fccab5 100644
--- a/internal/impl/convert_list.go
+++ b/internal/impl/convert_list.go
@@ -22,7 +22,7 @@
 }
 
 type listConverter struct {
-	goType reflect.Type
+	goType reflect.Type // []T
 	c      Converter
 }
 
@@ -48,11 +48,11 @@
 	if !ok {
 		return false
 	}
-	return list.v.Type().Elem() == c.goType && list.IsValid()
+	return list.v.Type().Elem() == c.goType
 }
 
 func (c *listConverter) IsValidGo(v reflect.Value) bool {
-	return v.Type() == c.goType
+	return v.IsValid() && v.Type() == c.goType
 }
 
 func (c *listConverter) New() pref.Value {
@@ -64,7 +64,7 @@
 }
 
 type listPtrConverter struct {
-	goType reflect.Type
+	goType reflect.Type // *[]T
 	c      Converter
 }
 
@@ -88,7 +88,7 @@
 }
 
 func (c *listPtrConverter) IsValidGo(v reflect.Value) bool {
-	return v.Type() == c.goType
+	return v.IsValid() && v.Type() == c.goType
 }
 
 func (c *listPtrConverter) New() pref.Value {
diff --git a/internal/impl/convert_map.go b/internal/impl/convert_map.go
index 3ef36d3..de06b25 100644
--- a/internal/impl/convert_map.go
+++ b/internal/impl/convert_map.go
@@ -12,7 +12,7 @@
 )
 
 type mapConverter struct {
-	goType           reflect.Type
+	goType           reflect.Type // map[K]V
 	keyConv, valConv Converter
 }
 
@@ -43,11 +43,11 @@
 	if !ok {
 		return false
 	}
-	return mapv.v.Type() == c.goType && mapv.IsValid()
+	return mapv.v.Type() == c.goType
 }
 
 func (c *mapConverter) IsValidGo(v reflect.Value) bool {
-	return v.Type() == c.goType
+	return v.IsValid() && v.Type() == c.goType
 }
 
 func (c *mapConverter) New() pref.Value {
diff --git a/internal/impl/message_reflect.go b/internal/impl/message_reflect.go
index aac55ee..28114ff 100644
--- a/internal/impl/message_reflect.go
+++ b/internal/impl/message_reflect.go
@@ -170,6 +170,8 @@
 		return x.Value().List().Len() > 0
 	case xd.IsMap():
 		return x.Value().Map().Len() > 0
+	case xd.Message() != nil:
+		return x.Value().Message().IsValid()
 	}
 	return true
 }
@@ -186,15 +188,28 @@
 	return xt.Zero()
 }
 func (m *extensionMap) Set(xt pref.ExtensionType, v pref.Value) {
-	if !xt.IsValidValue(v) {
+	xd := xt.TypeDescriptor()
+	isValid := true
+	switch {
+	case !xt.IsValidValue(v):
+		isValid = false
+	case xd.IsList():
+		isValid = v.List().IsValid()
+	case xd.IsMap():
+		isValid = v.Map().IsValid()
+	case xd.Message() != nil:
+		isValid = v.Message().IsValid()
+	}
+	if !isValid {
 		panic(fmt.Sprintf("%v: assigning invalid value", xt.TypeDescriptor().FullName()))
 	}
+
 	if *m == nil {
 		*m = make(map[int32]ExtensionField)
 	}
 	var x ExtensionField
 	x.Set(xt, v)
-	(*m)[int32(xt.TypeDescriptor().Number())] = x
+	(*m)[int32(xd.Number())] = x
 }
 func (m *extensionMap) Mutable(xt pref.ExtensionType) pref.Value {
 	xd := xt.TypeDescriptor()
diff --git a/proto/extension.go b/proto/extension.go
index 94af03f..5f293cd 100644
--- a/proto/extension.go
+++ b/proto/extension.go
@@ -9,40 +9,65 @@
 )
 
 // HasExtension reports whether an extension field is populated.
-// It panics if ext does not extend m.
-func HasExtension(m Message, ext protoreflect.ExtensionType) bool {
+// It returns false if m is invalid or if xt does not extend m.
+func HasExtension(m Message, xt protoreflect.ExtensionType) bool {
 	// Treat nil message interface as an empty message; no populated fields.
 	if m == nil {
 		return false
 	}
 
-	return m.ProtoReflect().Has(ext.TypeDescriptor())
+	// As a special-case, we reports invalid or mismatching descriptors
+	// as always not being populated (since they aren't).
+	if xt == nil || m.ProtoReflect().Descriptor() != xt.TypeDescriptor().ContainingMessage() {
+		return false
+	}
+
+	return m.ProtoReflect().Has(xt.TypeDescriptor())
 }
 
 // ClearExtension clears an extension field such that subsequent
 // HasExtension calls return false.
-// It panics if ext does not extend m.
-func ClearExtension(m Message, ext protoreflect.ExtensionType) {
-	m.ProtoReflect().Clear(ext.TypeDescriptor())
+// It panics if m is invalid or if xt does not extend m.
+func ClearExtension(m Message, xt protoreflect.ExtensionType) {
+	m.ProtoReflect().Clear(xt.TypeDescriptor())
 }
 
 // GetExtension retrieves the value for an extension field.
 // If the field is unpopulated, it returns the default value for
 // scalars and an immutable, empty value for lists or messages.
-// It panics if ext does not extend m.
-func GetExtension(m Message, ext protoreflect.ExtensionType) interface{} {
+// It panics if xt does not extend m.
+func GetExtension(m Message, xt protoreflect.ExtensionType) interface{} {
 	// Treat nil message interface as an empty message; return the default.
 	if m == nil {
-		return ext.InterfaceOf(ext.Zero())
+		return xt.InterfaceOf(xt.Zero())
 	}
 
-	return ext.InterfaceOf(m.ProtoReflect().Get(ext.TypeDescriptor()))
+	return xt.InterfaceOf(m.ProtoReflect().Get(xt.TypeDescriptor()))
 }
 
 // SetExtension stores the value of an extension field.
-// It panics if ext does not extend m or if value type is invalid for the field.
-func SetExtension(m Message, ext protoreflect.ExtensionType, value interface{}) {
-	m.ProtoReflect().Set(ext.TypeDescriptor(), ext.ValueOf(value))
+// It panics if m is invalid, xt does not extend m, or if type of v
+// is invalid for the specified extension field.
+func SetExtension(m Message, xt protoreflect.ExtensionType, v interface{}) {
+	xd := xt.TypeDescriptor()
+	pv := xt.ValueOf(v)
+
+	// Specially treat an invalid list, map, or message as clear.
+	isValid := true
+	switch {
+	case xd.IsList():
+		isValid = pv.List().IsValid()
+	case xd.IsMap():
+		isValid = pv.Map().IsValid()
+	case xd.Message() != nil:
+		isValid = pv.Message().IsValid()
+	}
+	if !isValid {
+		m.ProtoReflect().Clear(xd)
+		return
+	}
+
+	m.ProtoReflect().Set(xd, pv)
 }
 
 // RangeExtensions iterates over every populated extension field in m in an
diff --git a/proto/extension_test.go b/proto/extension_test.go
index 113160f..212a6f7 100644
--- a/proto/extension_test.go
+++ b/proto/extension_test.go
@@ -6,12 +6,14 @@
 
 import (
 	"fmt"
+	"reflect"
 	"sync"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
 
 	"google.golang.org/protobuf/proto"
+	"google.golang.org/protobuf/reflect/protoreflect"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 	"google.golang.org/protobuf/runtime/protoimpl"
 	"google.golang.org/protobuf/testing/protocmp"
@@ -69,6 +71,139 @@
 	}
 }
 
+func TestIsValid(t *testing.T) {
+	tests := []struct {
+		xt   protoreflect.ExtensionType
+		vi   interface{}
+		want bool
+	}{
+		{testpb.E_OptionalBool, nil, false},
+		{testpb.E_OptionalBool, bool(true), true},
+		{testpb.E_OptionalBool, new(bool), false},
+		{testpb.E_OptionalInt32, nil, false},
+		{testpb.E_OptionalInt32, int32(0), true},
+		{testpb.E_OptionalInt32, new(int32), false},
+		{testpb.E_OptionalInt64, nil, false},
+		{testpb.E_OptionalInt64, int64(0), true},
+		{testpb.E_OptionalInt64, new(int64), false},
+		{testpb.E_OptionalUint32, nil, false},
+		{testpb.E_OptionalUint32, uint32(0), true},
+		{testpb.E_OptionalUint32, new(uint32), false},
+		{testpb.E_OptionalUint64, nil, false},
+		{testpb.E_OptionalUint64, uint64(0), true},
+		{testpb.E_OptionalUint64, new(uint64), false},
+		{testpb.E_OptionalFloat, nil, false},
+		{testpb.E_OptionalFloat, float32(0), true},
+		{testpb.E_OptionalFloat, new(float32), false},
+		{testpb.E_OptionalDouble, nil, false},
+		{testpb.E_OptionalDouble, float64(0), true},
+		{testpb.E_OptionalDouble, new(float32), false},
+		{testpb.E_OptionalString, nil, false},
+		{testpb.E_OptionalString, string(""), true},
+		{testpb.E_OptionalString, new(string), false},
+		{testpb.E_OptionalNestedEnum, nil, false},
+		{testpb.E_OptionalNestedEnum, testpb.TestAllTypes_BAZ, true},
+		{testpb.E_OptionalNestedEnum, testpb.TestAllTypes_BAZ.Enum(), false},
+		{testpb.E_OptionalNestedMessage, nil, false},
+		{testpb.E_OptionalNestedMessage, (*testpb.TestAllExtensions_NestedMessage)(nil), true},
+		{testpb.E_OptionalNestedMessage, new(testpb.TestAllExtensions_NestedMessage), true},
+		{testpb.E_OptionalNestedMessage, new(testpb.TestAllExtensions), false},
+		{testpb.E_RepeatedBool, nil, false},
+		{testpb.E_RepeatedBool, []bool(nil), true},
+		{testpb.E_RepeatedBool, []bool{}, true},
+		{testpb.E_RepeatedBool, []bool{false}, true},
+		{testpb.E_RepeatedBool, []*bool{}, false},
+		{testpb.E_RepeatedInt32, nil, false},
+		{testpb.E_RepeatedInt32, []int32(nil), true},
+		{testpb.E_RepeatedInt32, []int32{}, true},
+		{testpb.E_RepeatedInt32, []int32{0}, true},
+		{testpb.E_RepeatedInt32, []*int32{}, false},
+		{testpb.E_RepeatedInt64, nil, false},
+		{testpb.E_RepeatedInt64, []int64(nil), true},
+		{testpb.E_RepeatedInt64, []int64{}, true},
+		{testpb.E_RepeatedInt64, []int64{0}, true},
+		{testpb.E_RepeatedInt64, []*int64{}, false},
+		{testpb.E_RepeatedUint32, nil, false},
+		{testpb.E_RepeatedUint32, []uint32(nil), true},
+		{testpb.E_RepeatedUint32, []uint32{}, true},
+		{testpb.E_RepeatedUint32, []uint32{0}, true},
+		{testpb.E_RepeatedUint32, []*uint32{}, false},
+		{testpb.E_RepeatedUint64, nil, false},
+		{testpb.E_RepeatedUint64, []uint64(nil), true},
+		{testpb.E_RepeatedUint64, []uint64{}, true},
+		{testpb.E_RepeatedUint64, []uint64{0}, true},
+		{testpb.E_RepeatedUint64, []*uint64{}, false},
+		{testpb.E_RepeatedFloat, nil, false},
+		{testpb.E_RepeatedFloat, []float32(nil), true},
+		{testpb.E_RepeatedFloat, []float32{}, true},
+		{testpb.E_RepeatedFloat, []float32{0}, true},
+		{testpb.E_RepeatedFloat, []*float32{}, false},
+		{testpb.E_RepeatedDouble, nil, false},
+		{testpb.E_RepeatedDouble, []float64(nil), true},
+		{testpb.E_RepeatedDouble, []float64{}, true},
+		{testpb.E_RepeatedDouble, []float64{0}, true},
+		{testpb.E_RepeatedDouble, []*float64{}, false},
+		{testpb.E_RepeatedString, nil, false},
+		{testpb.E_RepeatedString, []string(nil), true},
+		{testpb.E_RepeatedString, []string{}, true},
+		{testpb.E_RepeatedString, []string{""}, true},
+		{testpb.E_RepeatedString, []*string{}, false},
+		{testpb.E_RepeatedNestedEnum, nil, false},
+		{testpb.E_RepeatedNestedEnum, []testpb.TestAllTypes_NestedEnum(nil), true},
+		{testpb.E_RepeatedNestedEnum, []testpb.TestAllTypes_NestedEnum{}, true},
+		{testpb.E_RepeatedNestedEnum, []testpb.TestAllTypes_NestedEnum{0}, true},
+		{testpb.E_RepeatedNestedEnum, []*testpb.TestAllTypes_NestedEnum{}, false},
+		{testpb.E_RepeatedNestedMessage, nil, false},
+		{testpb.E_RepeatedNestedMessage, []*testpb.TestAllExtensions_NestedMessage(nil), true},
+		{testpb.E_RepeatedNestedMessage, []*testpb.TestAllExtensions_NestedMessage{}, true},
+		{testpb.E_RepeatedNestedMessage, []*testpb.TestAllExtensions_NestedMessage{{}}, true},
+		{testpb.E_RepeatedNestedMessage, []*testpb.TestAllExtensions{}, false},
+	}
+
+	for _, tt := range tests {
+		// Check the results of IsValidInterface.
+		got := tt.xt.IsValidInterface(tt.vi)
+		if got != tt.want {
+			t.Errorf("%v.IsValidInterface() = %v, want %v", tt.xt.TypeDescriptor().FullName(), got, tt.want)
+		}
+		if !got {
+			continue
+		}
+
+		// Set the extension value and verify the results of Has.
+		wantHas := true
+		pv := tt.xt.ValueOf(tt.vi)
+		switch v := pv.Interface().(type) {
+		case protoreflect.List:
+			wantHas = v.Len() > 0
+		case protoreflect.Message:
+			wantHas = v.IsValid()
+		}
+		m := &testpb.TestAllExtensions{}
+		proto.SetExtension(m, tt.xt, tt.vi)
+		gotHas := proto.HasExtension(m, tt.xt)
+		if gotHas != wantHas {
+			t.Errorf("HasExtension(%q) = %v, want %v", tt.xt.TypeDescriptor().FullName(), gotHas, wantHas)
+		}
+
+		// Check consistency of IsValidInterface and IsValidValue.
+		got = tt.xt.IsValidValue(pv)
+		if got != tt.want {
+			t.Errorf("%v.IsValidValue() = %v, want %v", tt.xt.TypeDescriptor().FullName(), got, tt.want)
+		}
+		if !got {
+			continue
+		}
+
+		// Use of reflect.DeepEqual is intentional.
+		// We really do want to ensure that the memory layout is identical.
+		vi := tt.xt.InterfaceOf(pv)
+		if !reflect.DeepEqual(vi, tt.vi) {
+			t.Errorf("InterfaceOf(ValueOf(...)) round-trip mismatch: got %v, want %v", vi, tt.vi)
+		}
+	}
+}
+
 func TestExtensionRanger(t *testing.T) {
 	want := map[pref.ExtensionType]interface{}{
 		testpb.E_OptionalInt32:         int32(5),
diff --git a/proto/nil_test.go b/proto/nil_test.go
index 9d13b2b..29d259d 100644
--- a/proto/nil_test.go
+++ b/proto/nil_test.go
@@ -97,7 +97,6 @@
 	}, {
 		label: "HasExtension",
 		test:  func() { proto.HasExtension(nilMsg, nil) },
-		panic: true,
 	}, {
 		label: "HasExtension",
 		test:  func() { proto.HasExtension(nilMsg, extType) },
diff --git a/types/dynamicpb/dynamic.go b/types/dynamicpb/dynamic.go
index 2a41fc8..7046ef2 100644
--- a/types/dynamicpb/dynamic.go
+++ b/types/dynamicpb/dynamic.go
@@ -215,7 +215,18 @@
 		panic(errors.New("%v: modification of read-only message", fd.FullName()))
 	}
 	if fd.IsExtension() {
-		if !fd.(pref.ExtensionTypeDescriptor).Type().IsValidValue(v) {
+		isValid := true
+		switch {
+		case !fd.(pref.ExtensionTypeDescriptor).Type().IsValidValue(v):
+			isValid = false
+		case fd.IsList():
+			isValid = v.List().IsValid()
+		case fd.IsMap():
+			isValid = v.Map().IsValid()
+		case fd.Message() != nil:
+			isValid = v.Message().IsValid()
+		}
+		if !isValid {
 			panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
 		}
 		m.ext[fd.Number()] = fd
@@ -467,6 +478,8 @@
 
 func typeIsValid(fd pref.FieldDescriptor, v pref.Value) error {
 	switch {
+	case !v.IsValid():
+		return errors.New("%v: assigning invalid value", fd.FullName())
 	case fd.IsMap():
 		if mapv, ok := v.Interface().(*dynamicMap); !ok || mapv.desc != fd || !mapv.IsValid() {
 			return errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface())