types/dynamicpb: support dynamic extensions

Add a dynamicpb.NewExtensionType function to permit creating extension
types from descriptors.

Also fix a some bugs around extension field handling:
When creating a new value for an extension field, use the
ExtensionType's Zero or New method to create the value.

Ensure that prototest exercises true zero-values of fields. (i.e.,
getting a list, map, or message from an empty message rather than
creating a new empty one with NewField.)

Change-Id: Idb8e87cdc92692610e12a4b8a68c34b129fae617
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/186180
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/testing/prototest/prototest.go b/testing/prototest/prototest.go
index fbfe7c6..20d20c9 100644
--- a/testing/prototest/prototest.go
+++ b/testing/prototest/prototest.go
@@ -29,6 +29,12 @@
 	//
 	// If nil, TestMessage will look for extension types in the global registry.
 	ExtensionTypes []pref.ExtensionType
+
+	// Resolver is used for looking up types when unmarshaling extension fields.
+	// If nil, this defaults to using protoregistry.GlobalTypes.
+	Resolver interface {
+		preg.ExtensionTypeResolver
+	}
 }
 
 // TestMessage runs the provided m through a series of tests
@@ -57,12 +63,20 @@
 	// Test round-trip marshal/unmarshal.
 	m2 := m.ProtoReflect().New().Interface()
 	populateMessage(m2.ProtoReflect(), 1, nil)
-	b, err := (proto.MarshalOptions{AllowPartial: true}).Marshal(m2)
+	for _, xt := range opts.ExtensionTypes {
+		m2.ProtoReflect().Set(xt.TypeDescriptor(), newValue(m2.ProtoReflect(), xt.TypeDescriptor(), 1, nil))
+	}
+	b, err := proto.MarshalOptions{
+		AllowPartial: true,
+	}.Marshal(m2)
 	if err != nil {
 		t.Errorf("Marshal() = %v, want nil\n%v", err, marshalText(m2))
 	}
 	m3 := m.ProtoReflect().New().Interface()
-	if err := (proto.UnmarshalOptions{AllowPartial: true}).Unmarshal(b, m3); err != nil {
+	if err := (proto.UnmarshalOptions{
+		AllowPartial: true,
+		Resolver:     opts.Resolver,
+	}.Unmarshal(b, m3)); err != nil {
 		t.Errorf("Unmarshal() = %v, want nil\n%v", err, marshalText(m2))
 	}
 	if !proto.Equal(m2, m3) {
@@ -150,7 +164,7 @@
 		}
 	case fd.IsMap():
 		if got := m.Get(fd); got.Map().Len() != 0 {
-			t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty list", name, num, formatValue(got))
+			t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty map", name, num, formatValue(got))
 		}
 	case fd.Message() == nil:
 		if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
@@ -158,6 +172,21 @@
 		}
 	}
 
+	// Set to the default value.
+	switch {
+	case fd.IsList() || fd.IsMap():
+		m.Set(fd, m.Get(fd))
+		if got, want := m.Has(fd), fd.IsExtension() || fd.ContainingOneof() != nil; got != want {
+			t.Errorf("after setting %q to default:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
+		}
+	case fd.Message() == nil:
+		m.Set(fd, m.Get(fd))
+		if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
+			t.Errorf("after setting %q to default:\nMessage.Get(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
+		}
+	}
+	m.Clear(fd)
+
 	// Set to the wrong type.
 	v := pref.ValueOf("")
 	if fd.Kind() == pref.StringKind {
@@ -508,26 +537,29 @@
 func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.MessageDescriptor) pref.Value {
 	switch {
 	case fd.IsList():
-		list := m.NewField(fd).List()
 		if n == 0 {
-			return pref.ValueOf(list)
+			return m.New().Get(fd)
 		}
+		list := m.NewField(fd).List()
 		list.Append(newListElement(fd, list, 0, stack))
 		list.Append(newListElement(fd, list, minVal, stack))
 		list.Append(newListElement(fd, list, maxVal, stack))
 		list.Append(newListElement(fd, list, n, stack))
 		return pref.ValueOf(list)
 	case fd.IsMap():
-		mapv := m.NewField(fd).Map()
 		if n == 0 {
-			return pref.ValueOf(mapv)
+			return m.New().Get(fd)
 		}
+		mapv := m.NewField(fd).Map()
 		mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, stack))
 		mapv.Set(newMapKey(fd, minVal), newMapValue(fd, mapv, minVal, stack))
 		mapv.Set(newMapKey(fd, maxVal), newMapValue(fd, mapv, maxVal, stack))
 		mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, newSeed(n, 0), stack))
 		return pref.ValueOf(mapv)
 	case fd.Message() != nil:
+		//if n == 0 {
+		//	return m.New().Get(fd)
+		//}
 		return populateMessage(m.NewField(fd).Message(), n, stack)
 	default:
 		return newScalarValue(fd, n)
diff --git a/types/dynamicpb/dynamic.go b/types/dynamicpb/dynamic.go
index 7b8c8d0..06616a1 100644
--- a/types/dynamicpb/dynamic.go
+++ b/types/dynamicpb/dynamic.go
@@ -122,16 +122,22 @@
 func (m *Message) Get(fd pref.FieldDescriptor) pref.Value {
 	m.checkField(fd)
 	num := fd.Number()
-	if v, ok := m.known[num]; ok {
-		if !fd.IsExtension() || fd == m.ext[num] {
-			return v
+	if fd.IsExtension() {
+		if fd != m.ext[num] {
+			return fd.(pref.ExtensionTypeDescriptor).Type().Zero()
 		}
+		return m.known[num]
+	}
+	if v, ok := m.known[num]; ok {
+		return v
 	}
 	switch {
 	case fd.IsMap():
 		return pref.ValueOf(&dynamicMap{desc: fd})
-	case fd.Cardinality() == pref.Repeated:
+	case fd.IsList():
 		return pref.ValueOf(emptyList{desc: fd})
+	case fd.Message() != nil:
+		return pref.ValueOf(&Message{desc: fd.Message()})
 	case fd.Kind() == pref.BytesKind:
 		return pref.ValueOf(append([]byte(nil), fd.Default().Bytes()...))
 	default:
@@ -143,15 +149,23 @@
 // See protoreflect.Message for details.
 func (m *Message) Mutable(fd pref.FieldDescriptor) pref.Value {
 	m.checkField(fd)
-	num := fd.Number()
-	if v, ok := m.known[num]; ok {
-		if !fd.IsExtension() || fd == m.ext[num] {
-			return v
-		}
-	}
 	if !fd.IsMap() && !fd.IsList() && fd.Message() == nil {
 		panic(errors.New("%v: getting mutable reference to non-composite type", fd.FullName()))
 	}
+	if m.known == nil {
+		panic(errors.New("%v: modification of read-only message", fd.FullName()))
+	}
+	num := fd.Number()
+	if fd.IsExtension() {
+		if fd != m.ext[num] {
+			m.ext[num] = fd
+			m.known[num] = fd.(pref.ExtensionTypeDescriptor).Type().New()
+		}
+		return m.known[num]
+	}
+	if v, ok := m.known[num]; ok {
+		return v
+	}
 	m.clearOtherOneofFields(fd)
 	m.known[num] = m.NewField(fd)
 	if fd.IsExtension() {
@@ -164,22 +178,16 @@
 // See protoreflect.Message for details.
 func (m *Message) Set(fd pref.FieldDescriptor, v pref.Value) {
 	m.checkField(fd)
-	switch {
-	case fd.IsExtension():
+	if m.known == nil {
+		panic(errors.New("%v: modification of read-only message", fd.FullName()))
+	}
+	if fd.IsExtension() {
 		if !fd.(pref.ExtensionTypeDescriptor).Type().IsValidValue(v) {
 			panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
 		}
 		m.ext[fd.Number()] = fd
-	case fd.IsMap():
-		if mapv, ok := v.Interface().(*dynamicMap); !ok || mapv.desc != fd {
-			panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
-		}
-	case fd.IsList():
-		if list, ok := v.Interface().(*dynamicList); !ok || list.desc != fd {
-			panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
-		}
-	default:
-		typecheckSingular(fd, v)
+	} else {
+		typecheck(fd, v)
 	}
 	m.clearOtherOneofFields(fd)
 	m.known[fd.Number()] = v
@@ -251,6 +259,9 @@
 // SetUnknown sets the raw unknown fields.
 // See protoreflect.Message for details.
 func (m *Message) SetUnknown(r pref.RawFields) {
+	if m.known == nil {
+		panic(errors.New("%v: modification of read-only message", m.desc.FullName()))
+	}
 	m.unknown = r
 }
 
@@ -406,7 +417,43 @@
 	return true
 }
 
+func typecheck(fd pref.FieldDescriptor, v pref.Value) {
+	if err := typeIsValid(fd, v); err != nil {
+		panic(err)
+	}
+}
+
+func typeIsValid(fd pref.FieldDescriptor, v pref.Value) error {
+	switch {
+	case fd.IsMap():
+		if mapv, ok := v.Interface().(*dynamicMap); !ok || mapv.desc != fd {
+			return errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface())
+		}
+		return nil
+	case fd.IsList():
+		switch list := v.Interface().(type) {
+		case *dynamicList:
+			if list.desc == fd {
+				return nil
+			}
+		case emptyList:
+			if list.desc == fd {
+				return nil
+			}
+		}
+		return errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface())
+	default:
+		return singularTypeIsValid(fd, v)
+	}
+}
+
 func typecheckSingular(fd pref.FieldDescriptor, v pref.Value) {
+	if err := singularTypeIsValid(fd, v); err != nil {
+		panic(err)
+	}
+}
+
+func singularTypeIsValid(fd pref.FieldDescriptor, v pref.Value) error {
 	vi := v.Interface()
 	var ok bool
 	switch fd.Kind() {
@@ -435,12 +482,16 @@
 		var m pref.Message
 		m, ok = vi.(pref.Message)
 		if ok && m.Descriptor().FullName() != fd.Message().FullName() {
-			panic(errors.New("%v: assigning invalid message type %v", fd.FullName(), m.Descriptor().FullName()))
+			return errors.New("%v: assigning invalid message type %v", fd.FullName(), m.Descriptor().FullName())
+		}
+		if dm, ok := vi.(*Message); ok && dm.known == nil {
+			return errors.New("%v: assigning invalid zero-value message", fd.FullName())
 		}
 	}
 	if !ok {
-		panic(errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface()))
+		return errors.New("%v: assigning invalid type %T", fd.FullName(), v.Interface())
 	}
+	return nil
 }
 
 func newListEntry(fd pref.FieldDescriptor) pref.Value {
@@ -470,3 +521,102 @@
 	}
 	panic(errors.New("%v: unknown kind %v", fd.FullName(), fd.Kind()))
 }
+
+// extensionType is a dynamic protoreflect.ExtensionType.
+type extensionType struct {
+	desc extensionTypeDescriptor
+}
+
+// NewExtensionType creates a new ExtensionType with the provided descriptor.
+//
+// Dynamic ExtensionTypes with the same descriptor compare as equal. That is,
+// if xd1 == xd2, then NewExtensionType(xd1) == NewExtensionType(xd2).
+//
+// The InterfaceOf and ValueOf methods of the extension type are defined as:
+//
+//	func (xt extensionType) ValueOf(iv interface{}) protoreflect.Value {
+//		return protoreflect.ValueOf(iv)
+//	}
+//
+//	func (xt extensionType) InterfaceOf(v protoreflect.Value) interface{} {
+//		return v.Interface()
+//	}
+//
+// The Go type used by the proto.GetExtension and proto.SetExtension functions
+// is determined by these methods, and is therefore equivalent to the Go type
+// used to represent a protoreflect.Value. See the protoreflect.Value
+// documentation for more details.
+func NewExtensionType(desc pref.ExtensionDescriptor) pref.ExtensionType {
+	if xt, ok := desc.(pref.ExtensionTypeDescriptor); ok {
+		desc = xt.Descriptor()
+	}
+	return extensionType{extensionTypeDescriptor{desc}}
+}
+
+func (xt extensionType) New() pref.Value {
+	switch {
+	case xt.desc.IsMap():
+		return pref.ValueOf(&dynamicMap{
+			desc: xt.desc,
+			mapv: make(map[interface{}]pref.Value),
+		})
+	case xt.desc.IsList():
+		return pref.ValueOf(&dynamicList{desc: xt.desc})
+	case xt.desc.Message() != nil:
+		return pref.ValueOf(New(xt.desc.Message()))
+	default:
+		return xt.desc.Default()
+	}
+}
+
+func (xt extensionType) Zero() pref.Value {
+	switch {
+	case xt.desc.IsMap():
+		return pref.ValueOf(&dynamicMap{desc: xt.desc})
+	case xt.desc.Cardinality() == pref.Repeated:
+		return pref.ValueOf(emptyList{desc: xt.desc})
+	case xt.desc.Message() != nil:
+		return pref.ValueOf(&Message{desc: xt.desc.Message()})
+	default:
+		return xt.desc.Default()
+	}
+}
+
+func (xt extensionType) GoType() reflect.Type {
+	return reflect.TypeOf(xt.InterfaceOf(xt.New()))
+}
+
+func (xt extensionType) TypeDescriptor() pref.ExtensionTypeDescriptor {
+	return xt.desc
+}
+
+func (xt extensionType) ValueOf(iv interface{}) pref.Value {
+	v := pref.ValueOf(iv)
+	typecheck(xt.desc, v)
+	return v
+}
+
+func (xt extensionType) InterfaceOf(v pref.Value) interface{} {
+	typecheck(xt.desc, v)
+	return v.Interface()
+}
+
+func (xt extensionType) IsValidInterface(iv interface{}) bool {
+	return typeIsValid(xt.desc, pref.ValueOf(iv)) == nil
+}
+
+func (xt extensionType) IsValidValue(v pref.Value) bool {
+	return typeIsValid(xt.desc, v) == nil
+}
+
+type extensionTypeDescriptor struct {
+	pref.ExtensionDescriptor
+}
+
+func (xt extensionTypeDescriptor) Type() pref.ExtensionType {
+	return extensionType{xt}
+}
+
+func (xt extensionTypeDescriptor) Descriptor() pref.ExtensionDescriptor {
+	return xt.ExtensionDescriptor
+}
diff --git a/types/dynamicpb/dynamic_test.go b/types/dynamicpb/dynamic_test.go
index 5a7db95..a4696d4 100644
--- a/types/dynamicpb/dynamic_test.go
+++ b/types/dynamicpb/dynamic_test.go
@@ -8,6 +8,8 @@
 	"testing"
 
 	"google.golang.org/protobuf/proto"
+	pref "google.golang.org/protobuf/reflect/protoreflect"
+	preg "google.golang.org/protobuf/reflect/protoregistry"
 	"google.golang.org/protobuf/testing/prototest"
 	"google.golang.org/protobuf/types/dynamicpb"
 
@@ -24,3 +26,37 @@
 		prototest.TestMessage(t, dynamicpb.New(message.ProtoReflect().Descriptor()), prototest.MessageOptions{})
 	}
 }
+
+func TestDynamicExtensions(t *testing.T) {
+	file, err := preg.GlobalFiles.FindFileByPath("test/ext.proto")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	md := (&testpb.TestAllExtensions{}).ProtoReflect().Descriptor()
+	opts := prototest.MessageOptions{
+		Resolver: extResolver{},
+	}
+	for i := 0; i < file.Extensions().Len(); i++ {
+		opts.ExtensionTypes = append(opts.ExtensionTypes, dynamicpb.NewExtensionType(file.Extensions().Get(i)))
+	}
+	prototest.TestMessage(t, dynamicpb.New(md), opts)
+}
+
+type extResolver struct{}
+
+func (extResolver) FindExtensionByName(field pref.FullName) (pref.ExtensionType, error) {
+	xt, err := preg.GlobalTypes.FindExtensionByName(field)
+	if err != nil {
+		return nil, err
+	}
+	return dynamicpb.NewExtensionType(xt.TypeDescriptor().Descriptor()), nil
+}
+
+func (extResolver) FindExtensionByNumber(message pref.FullName, field pref.FieldNumber) (pref.ExtensionType, error) {
+	xt, err := preg.GlobalTypes.FindExtensionByNumber(message, field)
+	if err != nil {
+		return nil, err
+	}
+	return dynamicpb.NewExtensionType(xt.TypeDescriptor().Descriptor()), nil
+}