reflect/prototype: simplify Go type descriptor constructors

The Go type descriptors protoreflect.{Enum,Message,Extension}Type are simple
wrappers over protoreflect.{Enum,Message,Extension}Descriptor with a small
number of additional methods. It is very unlikely that more will be added in
the near future.

For this reason, construct the types directly using arguments to the constructor
function, as opposed to taking in another struct (which was originally done
to provide flexibility in-case we needed more fields).

Furthmore, rename GoNew and New.

Change-Id: Ic7fb5bc250cdb2761ae03b388b5147ff50f37d15
Reviewed-on: https://go-review.googlesource.com/c/148822
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/internal/impl/message.go b/internal/impl/message.go
index c406b3d..c6f247e 100644
--- a/internal/impl/message.go
+++ b/internal/impl/message.go
@@ -66,12 +66,9 @@
 		//
 		// Generated code ensures that this property holds.
 		if _, ok := p.(pref.ProtoMessage); !ok {
-			mi.pbType = ptype.NewGoMessage(&ptype.GoMessage{
-				MessageDescriptor: mi.Desc,
-				New: func(pref.MessageType) pref.ProtoMessage {
-					p := reflect.New(t.Elem()).Interface()
-					return (*message)(mi.dataTypeOf(p))
-				},
+			mi.pbType = ptype.GoMessage(mi.Desc, func(pref.MessageType) pref.ProtoMessage {
+				p := reflect.New(t.Elem()).Interface()
+				return (*message)(mi.dataTypeOf(p))
 			})
 		}
 
diff --git a/internal/value/convert.go b/internal/value/convert.go
index 2b91171..cd3402a 100644
--- a/internal/value/convert.go
+++ b/internal/value/convert.go
@@ -111,7 +111,7 @@
 					return pref.ValueOf(e.ProtoReflect().Number())
 				},
 				toGo: func(v pref.Value) reflect.Value {
-					rv := reflect.ValueOf(et.GoNew(v.Enum()))
+					rv := reflect.ValueOf(et.New(v.Enum()))
 					if rv.Type() != t {
 						panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), t))
 					}
@@ -153,7 +153,7 @@
 					return rv
 				},
 				newMessage: func() pref.Message {
-					return mt.GoNew().ProtoReflect()
+					return mt.New().ProtoReflect()
 				},
 			}
 		}
diff --git a/reflect/protoreflect/type.go b/reflect/protoreflect/type.go
index 6e3ff11..1fcc575 100644
--- a/reflect/protoreflect/type.go
+++ b/reflect/protoreflect/type.go
@@ -6,8 +6,6 @@
 
 import "reflect"
 
-// TODO: Rename GoNew as New for MessageType, EnumType, and ExtensionType?
-
 // TODO: For all ByX methods (e.g., ByName, ByJSONName, ByNumber, etc),
 // should they use a (v, ok) signature for the return value?
 
@@ -259,12 +257,12 @@
 type MessageType interface {
 	MessageDescriptor
 
-	// GoNew returns a newly allocated empty message.
-	GoNew() ProtoMessage
+	// New returns a newly allocated empty message.
+	New() ProtoMessage
 
 	// GoType returns the Go type of the allocated message.
 	//
-	// Invariant: t.GoType() == reflect.TypeOf(t.GoNew())
+	// Invariant: t.GoType() == reflect.TypeOf(t.New())
 	GoType() reflect.Type
 }
 
@@ -437,15 +435,15 @@
 type ExtensionType interface {
 	ExtensionDescriptor
 
-	// GoNew returns a new value for the field.
+	// New returns a new value for the field.
 	// For scalars, this returns the default value in native Go form.
-	GoNew() interface{}
+	New() interface{}
 
 	// GoType returns the Go type of the field value.
 	//
 	// Invariants:
-	//	t.GoType() == reflect.TypeOf(t.GoNew())
-	//	t.GoType() == reflect.TypeOf(t.InterfaceOf(t.ValueOf(t.GoNew())))
+	//	t.GoType() == reflect.TypeOf(t.New())
+	//	t.GoType() == reflect.TypeOf(t.InterfaceOf(t.ValueOf(t.New())))
 	GoType() reflect.Type
 
 	// TODO: How do we reconcile GoType with the existing extension API,
@@ -487,12 +485,12 @@
 type EnumType interface {
 	EnumDescriptor
 
-	// GoNew returns an instance of this enum type with its value set to n.
-	GoNew(n EnumNumber) ProtoEnum
+	// New returns an instance of this enum type with its value set to n.
+	New(n EnumNumber) ProtoEnum
 
 	// GoType returns the Go type of the enum value.
 	//
-	// Invariants: t.GoType() == reflect.TypeOf(t.GoNew(0))
+	// Invariants: t.GoType() == reflect.TypeOf(t.New(0))
 	GoType() reflect.Type
 }
 
diff --git a/reflect/prototype/go_type.go b/reflect/prototype/go_type.go
index 875d9c5..10f1f54 100644
--- a/reflect/prototype/go_type.go
+++ b/reflect/prototype/go_type.go
@@ -13,125 +13,74 @@
 	"github.com/golang/protobuf/v2/reflect/protoreflect"
 )
 
-// GoEnum is a constructor for a protoreflect.EnumType.
-type GoEnum struct {
-	protoreflect.EnumDescriptor
-
-	// New returns a concrete proto.Enum value with the given enum number.
-	// The constructor must return the same concrete type for each invocation.
-	New func(protoreflect.EnumType, protoreflect.EnumNumber) protoreflect.ProtoEnum
-
-	once   sync.Once
-	goType reflect.Type
-}
-type goEnum struct{ *GoEnum }
-
-// NewGoEnum creates a new protoreflect.EnumType.
-//
-// The caller must relinquish full ownership of the input t and must not
-// access or mutate any fields.
-func NewGoEnum(t *GoEnum) protoreflect.EnumType {
-	if t.IsPlaceholder() {
+// GoEnum creates a new protoreflect.EnumType by combining the provided
+// protoreflect.EnumDescriptor with the provided constructor function.
+func GoEnum(ed protoreflect.EnumDescriptor, fn func(protoreflect.EnumType, protoreflect.EnumNumber) protoreflect.ProtoEnum) protoreflect.EnumType {
+	if ed.IsPlaceholder() {
 		panic("enum descriptor must not be a placeholder")
 	}
-	if t.New == nil {
-		panic("invalid nil constructor for enum kind")
-	}
-	return goEnum{t}
+	t := &goEnum{EnumDescriptor: ed, new: fn}
+	t.typ = reflect.TypeOf(fn(t, 0))
+	return t
 }
-func (p goEnum) GoNew(n protoreflect.EnumNumber) protoreflect.ProtoEnum {
-	e := p.New(p, n)
-	p.once.Do(func() { p.goType = reflect.TypeOf(e) })
-	if p.goType != reflect.TypeOf(e) {
-		panic(fmt.Sprintf("mismatching types for enum: got %T, want %v", e, p.goType))
+
+type goEnum struct {
+	protoreflect.EnumDescriptor
+	typ reflect.Type
+	new func(protoreflect.EnumType, protoreflect.EnumNumber) protoreflect.ProtoEnum
+}
+
+func (t *goEnum) GoType() reflect.Type {
+	return t.typ
+}
+func (t *goEnum) New(n protoreflect.EnumNumber) protoreflect.ProtoEnum {
+	e := t.new(t, n)
+	if t.typ != reflect.TypeOf(e) {
+		panic(fmt.Sprintf("mismatching types for enum: got %T, want %v", e, t.typ))
 	}
 	return e
 }
-func (p goEnum) GoType() reflect.Type {
-	p.once.Do(func() { p.goType = reflect.TypeOf(p.New(p, 0)) })
-	return p.goType
-}
 
-// GoMessage is a constructor for a protoreflect.MessageType.
-type GoMessage struct {
-	protoreflect.MessageDescriptor
-
-	// New returns a new empty proto.Message instance.
-	// The constructor must return the same concrete type for each invocation.
-	New func(protoreflect.MessageType) protoreflect.ProtoMessage
-
-	once   sync.Once
-	goType reflect.Type
-}
-type goMessage struct{ *GoMessage }
-
-// NewGoMessage creates a new protoreflect.MessageType.
-//
-// The caller must relinquish full ownership of the input t and must not
-// access or mutate any fields.
-func NewGoMessage(t *GoMessage) protoreflect.MessageType {
-	if t.IsPlaceholder() {
+// GoMessage creates a new protoreflect.MessageType by combining the provided
+// protoreflect.MessageDescriptor with the provided constructor function.
+func GoMessage(md protoreflect.MessageDescriptor, fn func(protoreflect.MessageType) protoreflect.ProtoMessage) protoreflect.MessageType {
+	if md.IsPlaceholder() {
 		panic("message descriptor must not be a placeholder")
 	}
-	if t.New == nil {
-		panic("invalid nil constructor for message kind")
-	}
-	return goMessage{t}
+	t := &goMessage{MessageDescriptor: md, new: fn}
+	t.typ = reflect.TypeOf(fn(t))
+	return t
 }
-func (p goMessage) GoNew() protoreflect.ProtoMessage {
-	m := p.New(p)
-	p.once.Do(func() { p.goType = reflect.TypeOf(m) })
-	if p.goType != reflect.TypeOf(m) {
-		panic(fmt.Sprintf("mismatching types for message: got %T, want %v", m, p.goType))
+
+type goMessage struct {
+	protoreflect.MessageDescriptor
+	typ reflect.Type
+	new func(protoreflect.MessageType) protoreflect.ProtoMessage
+}
+
+func (t *goMessage) GoType() reflect.Type {
+	return t.typ
+}
+func (t *goMessage) New() protoreflect.ProtoMessage {
+	m := t.new(t)
+	if t.typ != reflect.TypeOf(m) {
+		panic(fmt.Sprintf("mismatching types for message: got %T, want %v", m, t.typ))
 	}
 	return m
 }
-func (p goMessage) GoType() reflect.Type {
-	p.once.Do(func() { p.goType = reflect.TypeOf(p.New(p)) })
-	return p.goType
-}
 
-// GoExtension is a constructor for a protoreflect.ExtensionType.
-type GoExtension struct {
-	protoreflect.ExtensionDescriptor
-
-	// NewEnum returns a concrete proto.Enum value with the given enum number.
-	// The constructor must be provided if protoreflect.ExtensionDescriptor.Kind
-	// is protoreflect.EnumKind.
-	//
-	// The returned enum must represent an protoreflect.EnumDescriptor
-	// that matches protoreflect.ExtensionDescriptor.EnumType.
-	NewEnum func(protoreflect.EnumNumber) protoreflect.ProtoEnum
-
-	// NewMessage returns a new empty proto.Message instance.
-	// The constructor must be provided if protoreflect.ExtensionDescriptor.Kind
-	// is protoreflect.MessageKind or protoreflect.GroupKind.
-	//
-	// The returned message must represent an protoreflect.MessageDescriptor
-	// that matches protoreflect.ExtensionDescriptor.MessageType.
-	NewMessage func() protoreflect.ProtoMessage
-
-	// TODO: Separate NewEnum and NewMessage constructors make it possible for
-	// users to provide a constructor that returns a Go type does not match
-	// the corresponding protobuf descriptor in ExtensionDescriptor.
-	// Checking for correctness is hard since descriptors are not comparable.
-	//
-	// An alternative API is for ExtensionDescriptor.{EnumType,MessageType}
-	// to document that it must implement protoreflect.{EnumType,MessageType}.
-
-	once        sync.Once
-	new         func() interface{}
-	goType      reflect.Type
-	valueOf     func(v interface{}) protoreflect.Value
-	interfaceOf func(v protoreflect.Value) interface{}
-}
-type goExtension struct{ *GoExtension }
-
-// NewGoExtension creates a new protoreflect.ExtensionType.
+// GoExtension creates a new protoreflect.ExtensionType.
 //
-// The Go type is currently determined automatically (although custom Go types
-// may be supported in the future). The type is T for scalars and
-// *[]T for vectors. Maps are not valid in extension fields.
+// An enum type must be provided for enum extension fields if
+// ExtensionDescriptor.EnumType does not implement protoreflect.EnumType,
+// in which case it replaces the original enum in ExtensionDescriptor.
+//
+// Similarly, a message type must be provided for message extension fields if
+// ExtensionDescriptor.MessageType does not implement protoreflect.MessageType,
+// in which case it replaces the original message in ExtensionDescriptor.
+//
+// The Go type is currently determined automatically.
+// The type is T for scalars and *[]T for vectors (maps are not allowed).
 // The type T is determined as follows:
 //
 //	+------------+-------------------------------------+
@@ -154,121 +103,151 @@
 // which is often, but not required to be, a named int32 type.
 // The type M is the concrete message type returned by NewMessage,
 // which is often, but not required to be, a pointer to a named struct type.
-//
-// The caller must relinquish full ownership of the input t and must not
-// access or mutate any fields.
-func NewGoExtension(t *GoExtension) protoreflect.ExtensionType {
-	if t.ExtendedType() == nil {
+func GoExtension(xd protoreflect.ExtensionDescriptor, et protoreflect.EnumType, mt protoreflect.MessageType) protoreflect.ExtensionType {
+	if xd.ExtendedType() == nil {
 		panic("field descriptor does not extend a message")
 	}
-	switch t.Kind() {
+	switch xd.Kind() {
 	case protoreflect.EnumKind:
-		if t.NewEnum == nil {
-			panic("enum constructor not provided for enum kind")
+		if et2, ok := xd.EnumType().(protoreflect.EnumType); ok && et == nil {
+			et = et2
 		}
-		if t.NewMessage != nil {
-			panic("message constructor provided for enum kind")
+		if et == nil {
+			panic("enum type not provided for enum kind")
+		}
+		if mt != nil {
+			panic("message type provided for enum kind")
 		}
 	case protoreflect.MessageKind, protoreflect.GroupKind:
-		if t.NewMessage == nil {
-			panic("message constructor not provided for message kind")
+		if mt2, ok := xd.MessageType().(protoreflect.MessageType); ok && mt == nil {
+			mt = mt2
 		}
-		if t.NewEnum != nil {
-			panic("enum constructor provided for message kind")
+		if et != nil {
+			panic("enum type provided for message kind")
+		}
+		if mt == nil {
+			panic("message type not provided for message kind")
 		}
 	default:
-		if t.NewMessage != nil || t.NewEnum != nil {
-			panic(fmt.Sprintf("enum or message constructor provided for %v kind", t.Kind()))
+		if et != nil || mt != nil {
+			panic(fmt.Sprintf("enum or message type provided for %v kind", xd.Kind()))
 		}
 	}
-	return goExtension{t}
+	return &goExtension{ExtensionDescriptor: xd, enumType: et, messageType: mt}
 }
-func (p goExtension) GoNew() interface{} {
-	p.lazyInit()
-	v := p.new()
-	if reflect.TypeOf(v) != p.goType {
-		panic(fmt.Sprintf("invalid type: got %T, want %v", v, p.goType))
+
+type goExtension struct {
+	protoreflect.ExtensionDescriptor
+	enumType    protoreflect.EnumType
+	messageType protoreflect.MessageType
+
+	once        sync.Once
+	typ         reflect.Type
+	new         func() interface{}
+	valueOf     func(v interface{}) protoreflect.Value
+	interfaceOf func(v protoreflect.Value) interface{}
+}
+
+func (t *goExtension) EnumType() protoreflect.EnumDescriptor {
+	return t.enumType
+}
+func (t *goExtension) MessageType() protoreflect.MessageDescriptor {
+	return t.messageType
+}
+func (t *goExtension) GoType() reflect.Type {
+	t.lazyInit()
+	return t.typ
+}
+func (t *goExtension) New() interface{} {
+	t.lazyInit()
+	v := t.new()
+	if reflect.TypeOf(v) != t.typ {
+		panic(fmt.Sprintf("invalid type: got %T, want %v", v, t.typ))
 	}
 	return v
 }
-func (p goExtension) GoType() reflect.Type {
-	p.lazyInit()
-	return p.goType
-}
-func (p goExtension) ValueOf(v interface{}) protoreflect.Value {
-	p.lazyInit()
-	if reflect.TypeOf(v) != p.goType {
-		panic(fmt.Sprintf("invalid type: got %T, want %v", v, p.goType))
+func (t *goExtension) ValueOf(v interface{}) protoreflect.Value {
+	t.lazyInit()
+	if reflect.TypeOf(v) != t.typ {
+		panic(fmt.Sprintf("invalid type: got %T, want %v", v, t.typ))
 	}
-	return p.valueOf(v)
+	return t.valueOf(v)
 }
-func (p goExtension) InterfaceOf(pv protoreflect.Value) interface{} {
-	p.lazyInit()
-	v := p.interfaceOf(pv)
-	if reflect.TypeOf(v) != p.goType {
-		panic(fmt.Sprintf("invalid type: got %T, want %v", v, p.goType))
+func (t *goExtension) InterfaceOf(pv protoreflect.Value) interface{} {
+	t.lazyInit()
+	v := t.interfaceOf(pv)
+	if reflect.TypeOf(v) != t.typ {
+		panic(fmt.Sprintf("invalid type: got %T, want %v", v, t.typ))
 	}
 	return v
 }
-func (p goExtension) lazyInit() {
-	p.once.Do(func() {
-		switch p.Cardinality() {
+func (t *goExtension) lazyInit() {
+	t.once.Do(func() {
+		switch t.Cardinality() {
 		case protoreflect.Optional:
-			switch p.Kind() {
+			switch t.Kind() {
 			case protoreflect.EnumKind:
-				p.goType = reflect.TypeOf(p.NewEnum(0))
-				p.new = func() interface{} { return p.NewEnum(p.Default().Enum()) }
-				p.valueOf = func(v interface{}) protoreflect.Value {
+				t.typ = t.enumType.GoType()
+				t.new = func() interface{} {
+					return t.enumType.New(t.Default().Enum())
+				}
+				t.valueOf = func(v interface{}) protoreflect.Value {
 					ev := v.(protoreflect.ProtoEnum).ProtoReflect()
 					return protoreflect.ValueOf(ev.Number())
 				}
-				p.interfaceOf = func(pv protoreflect.Value) interface{} {
-					return p.NewEnum(pv.Enum())
+				t.interfaceOf = func(pv protoreflect.Value) interface{} {
+					return t.enumType.New(pv.Enum())
 				}
 			case protoreflect.MessageKind, protoreflect.GroupKind:
-				p.goType = reflect.TypeOf(p.NewMessage())
-				p.new = func() interface{} { return p.NewMessage() }
-				p.valueOf = func(v interface{}) protoreflect.Value {
-					return protoreflect.ValueOf(v)
+				t.typ = t.messageType.GoType()
+				t.new = func() interface{} {
+					return t.messageType.New()
 				}
-				p.interfaceOf = func(pv protoreflect.Value) interface{} {
+				t.valueOf = func(v interface{}) protoreflect.Value {
+					mv := v.(protoreflect.ProtoMessage).ProtoReflect()
+					return protoreflect.ValueOf(mv)
+				}
+				t.interfaceOf = func(pv protoreflect.Value) interface{} {
 					return pv.Message().Interface()
 				}
 			default:
-				p.goType = goTypeForPBKind[p.Kind()]
-				p.new = func() interface{} { return p.Default().Interface() }
-				p.valueOf = func(v interface{}) protoreflect.Value {
+				t.typ = goTypeForPBKind[t.Kind()]
+				t.new = func() interface{} {
+					return t.Default().Interface()
+				}
+				t.valueOf = func(v interface{}) protoreflect.Value {
 					return protoreflect.ValueOf(v)
 				}
-				p.interfaceOf = func(pv protoreflect.Value) interface{} {
-					v := pv.Interface()
-					return v
+				t.interfaceOf = func(pv protoreflect.Value) interface{} {
+					return pv.Interface()
 				}
 			}
 		case protoreflect.Repeated:
-			var goType reflect.Type
-			switch p.Kind() {
+			var typ reflect.Type
+			switch t.Kind() {
 			case protoreflect.EnumKind:
-				goType = reflect.TypeOf(p.NewEnum(p.Default().Enum()))
+				typ = t.enumType.GoType()
 			case protoreflect.MessageKind, protoreflect.GroupKind:
-				goType = reflect.TypeOf(p.NewMessage())
+				typ = t.messageType.GoType()
 			default:
-				goType = goTypeForPBKind[p.Kind()]
+				typ = goTypeForPBKind[t.Kind()]
 			}
-			c := value.NewConverter(goType, p.Kind())
-			p.goType = reflect.PtrTo(reflect.SliceOf(goType))
-			p.new = func() interface{} { return reflect.New(p.goType.Elem()).Interface() }
-			p.valueOf = func(v interface{}) protoreflect.Value {
+			c := value.NewConverter(typ, t.Kind())
+			t.typ = reflect.PtrTo(reflect.SliceOf(typ))
+			t.new = func() interface{} {
+				return reflect.New(t.typ.Elem()).Interface()
+			}
+			t.valueOf = func(v interface{}) protoreflect.Value {
 				return protoreflect.ValueOf(value.VectorOf(v, c))
 			}
-			p.interfaceOf = func(v protoreflect.Value) interface{} {
+			t.interfaceOf = func(v protoreflect.Value) interface{} {
 				// TODO: Can we assume that Vector implementations know how
 				// to unwrap themselves?
 				// Should this be part of the public API in protoreflect?
 				return v.Vector().(value.Unwrapper).Unwrap()
 			}
 		default:
-			panic(fmt.Sprintf("invalid cardinality: %v", p.Cardinality()))
+			panic(fmt.Sprintf("invalid cardinality: %v", t.Cardinality()))
 		}
 	})
 }