internal/impl: support legacy message names

This change:
* Adds aberrant support for the undocumented XXX_MessageName method.
* Adds LegacyMessageTypeOf so that v1 registration can suggest a
fullname to use with a legacy message with no Descriptor support.

Change-Id: I0265bd3cf67f4d4815358148f5817695c1122dc8
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/193518
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/api_export.go b/internal/impl/api_export.go
index f2b3738..52403cf 100644
--- a/internal/impl/api_export.go
+++ b/internal/impl/api_export.go
@@ -11,7 +11,6 @@
 
 	"google.golang.org/protobuf/encoding/prototext"
 	"google.golang.org/protobuf/proto"
-	"google.golang.org/protobuf/reflect/protoreflect"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 	piface "google.golang.org/protobuf/runtime/protoiface"
 )
@@ -63,7 +62,7 @@
 type message = interface{}
 
 // legacyMessageWrapper wraps a v2 message as a v1 message.
-type legacyMessageWrapper struct{ m protoreflect.ProtoMessage }
+type legacyMessageWrapper struct{ m pref.ProtoMessage }
 
 func (m legacyMessageWrapper) Reset()         { proto.Reset(m.m) }
 func (m legacyMessageWrapper) String() string { return Export{}.MessageStringOf(m.m) }
@@ -76,67 +75,56 @@
 		return mv
 	case unwrapper:
 		return Export{}.ProtoMessageV1Of(mv.protoUnwrap())
-	case protoreflect.ProtoMessage:
+	case pref.ProtoMessage:
 		return legacyMessageWrapper{mv}
 	default:
 		panic(fmt.Sprintf("message %T is neither a v1 or v2 Message", m))
 	}
 }
 
+func (Export) protoMessageV2Of(m message) pref.ProtoMessage {
+	switch mv := m.(type) {
+	case pref.ProtoMessage:
+		return mv
+	case legacyMessageWrapper:
+		return mv.m
+	case piface.MessageV1:
+		return nil
+	default:
+		panic(fmt.Sprintf("message %T is neither a v1 or v2 Message", m))
+	}
+}
+
 // ProtoMessageV2Of converts either a v1 or v2 message to a v2 message.
 func (Export) ProtoMessageV2Of(m message) pref.ProtoMessage {
-	switch mv := m.(type) {
-	case protoreflect.ProtoMessage:
+	if mv := (Export{}).protoMessageV2Of(m); mv != nil {
 		return mv
-	case legacyMessageWrapper:
-		return mv.m
-	case piface.MessageV1:
-		return legacyWrapMessage(reflect.ValueOf(mv))
-	default:
-		panic(fmt.Sprintf("message %T is neither a v1 or v2 Message", m))
 	}
+	return legacyWrapMessage(reflect.ValueOf(m))
 }
 
 // MessageOf returns the protoreflect.Message interface over m.
 func (Export) MessageOf(m message) pref.Message {
-	switch mv := m.(type) {
-	case pref.ProtoMessage:
+	if mv := (Export{}).protoMessageV2Of(m); mv != nil {
 		return mv.ProtoReflect()
-	case legacyMessageWrapper:
-		return mv.m.ProtoReflect()
-	case piface.MessageV1:
-		return legacyWrapMessage(reflect.ValueOf(mv)).ProtoReflect()
-	default:
-		panic(fmt.Sprintf("message %T is neither a v1 or v2 Message", m))
 	}
+	return legacyWrapMessage(reflect.ValueOf(m)).ProtoReflect()
 }
 
 // MessageDescriptorOf returns the protoreflect.MessageDescriptor for m.
 func (Export) MessageDescriptorOf(m message) pref.MessageDescriptor {
-	switch mv := m.(type) {
-	case pref.ProtoMessage:
+	if mv := (Export{}).protoMessageV2Of(m); mv != nil {
 		return mv.ProtoReflect().Descriptor()
-	case legacyMessageWrapper:
-		return mv.m.ProtoReflect().Descriptor()
-	case piface.MessageV1:
-		return LegacyLoadMessageDesc(reflect.TypeOf(mv))
-	default:
-		panic(fmt.Sprintf("message %T is neither a v1 or v2 Message", m))
 	}
+	return LegacyLoadMessageDesc(reflect.TypeOf(m))
 }
 
 // MessageTypeOf returns the protoreflect.MessageType for m.
 func (Export) MessageTypeOf(m message) pref.MessageType {
-	switch mv := m.(type) {
-	case pref.ProtoMessage:
+	if mv := (Export{}).protoMessageV2Of(m); mv != nil {
 		return mv.ProtoReflect().Type()
-	case legacyMessageWrapper:
-		return mv.m.ProtoReflect().Type()
-	case piface.MessageV1:
-		return legacyLoadMessageInfo(reflect.TypeOf(mv))
-	default:
-		panic(fmt.Sprintf("message %T is neither a v1 or v2 Message", m))
 	}
+	return legacyLoadMessageInfo(reflect.TypeOf(m), "")
 }
 
 // MessageStringOf returns the message value as a string,
diff --git a/internal/impl/legacy_export.go b/internal/impl/legacy_export.go
index 29c1b01..989e944 100644
--- a/internal/impl/legacy_export.go
+++ b/internal/impl/legacy_export.go
@@ -26,6 +26,15 @@
 	return legacyEnumName(ed)
 }
 
+// LegacyMessageTypeOf returns the protoreflect.MessageType for m,
+// with name used as the message name if necessary.
+func (Export) LegacyMessageTypeOf(m piface.MessageV1, name pref.FullName) pref.MessageType {
+	if mv := (Export{}).protoMessageV2Of(m); mv != nil {
+		return mv.ProtoReflect().Type()
+	}
+	return legacyLoadMessageInfo(reflect.TypeOf(m), name)
+}
+
 // UnmarshalJSONEnum unmarshals an enum from a JSON-encoded input.
 // The input can either be a string representing the enum value by name,
 // or a number representing the enum number itself.
diff --git a/internal/impl/legacy_message.go b/internal/impl/legacy_message.go
index c70e30f..ff1c69b 100644
--- a/internal/impl/legacy_message.go
+++ b/internal/impl/legacy_message.go
@@ -21,7 +21,7 @@
 // legacyWrapMessage wraps v as a protoreflect.ProtoMessage,
 // where v must be a *struct kind and not implement the v2 API already.
 func legacyWrapMessage(v reflect.Value) pref.ProtoMessage {
-	mt := legacyLoadMessageInfo(v.Type())
+	mt := legacyLoadMessageInfo(v.Type(), "")
 	return mt.MessageOf(v.Interface()).Interface()
 }
 
@@ -29,7 +29,8 @@
 
 // legacyLoadMessageInfo dynamically loads a *MessageInfo for t,
 // where t must be a *struct kind and not implement the v2 API already.
-func legacyLoadMessageInfo(t reflect.Type) *MessageInfo {
+// The provided name is used if it cannot be determined from the message.
+func legacyLoadMessageInfo(t reflect.Type, name pref.FullName) *MessageInfo {
 	// Fast-path: check if a MessageInfo is cached for this concrete type.
 	if mt, ok := legacyMessageTypeCache.Load(t); ok {
 		return mt.(*MessageInfo)
@@ -37,7 +38,7 @@
 
 	// Slow-path: derive message descriptor and initialize MessageInfo.
 	mi := &MessageInfo{
-		Desc:          LegacyLoadMessageDesc(t),
+		Desc:          legacyLoadMessageDesc(t, name),
 		GoReflectType: t,
 	}
 	if mi, ok := legacyMessageTypeCache.LoadOrStore(t, mi); ok {
@@ -53,6 +54,9 @@
 //
 // This is exported for testing purposes.
 func LegacyLoadMessageDesc(t reflect.Type) pref.MessageDescriptor {
+	return legacyLoadMessageDesc(t, "")
+}
+func legacyLoadMessageDesc(t reflect.Type, name pref.FullName) pref.MessageDescriptor {
 	// Fast-path: check if a MessageDescriptor is cached for this concrete type.
 	if mi, ok := legacyMessageDescCache.Load(t); ok {
 		return mi.(pref.MessageDescriptor)
@@ -65,7 +69,7 @@
 	}
 	mdV1, ok := mv.(messageV1)
 	if !ok {
-		return aberrantLoadMessageDesc(t)
+		return aberrantLoadMessageDesc(t, name)
 	}
 	b, idxs := mdV1.Descriptor()
 
@@ -73,6 +77,9 @@
 	for _, i := range idxs[1:] {
 		md = md.Messages().Get(i)
 	}
+	if name != "" && md.FullName() != name {
+		panic(fmt.Sprintf("mismatching message name: got %v, want %v", md.FullName(), name))
+	}
 	if md, ok := legacyMessageDescCache.LoadOrStore(t, md); ok {
 		return md.(protoreflect.MessageDescriptor)
 	}
@@ -89,15 +96,15 @@
 //
 // This is a best-effort derivation of the message descriptor using the protobuf
 // tags on the struct fields.
-func aberrantLoadMessageDesc(t reflect.Type) pref.MessageDescriptor {
+func aberrantLoadMessageDesc(t reflect.Type, name pref.FullName) pref.MessageDescriptor {
 	aberrantMessageDescLock.Lock()
 	defer aberrantMessageDescLock.Unlock()
 	if aberrantMessageDescCache == nil {
 		aberrantMessageDescCache = make(map[reflect.Type]protoreflect.MessageDescriptor)
 	}
-	return aberrantLoadMessageDescReentrant(t)
+	return aberrantLoadMessageDescReentrant(t, name)
 }
-func aberrantLoadMessageDescReentrant(t reflect.Type) pref.MessageDescriptor {
+func aberrantLoadMessageDescReentrant(t reflect.Type, name pref.FullName) pref.MessageDescriptor {
 	// Fast-path: check if an MessageDescriptor is cached for this concrete type.
 	if md, ok := aberrantMessageDescCache[t]; ok {
 		return md
@@ -107,7 +114,7 @@
 	// Cache the MessageDescriptor early on so that we can resolve internal
 	// cyclic references.
 	md := &filedesc.Message{L2: new(filedesc.MessageL2)}
-	md.L0.FullName = aberrantDeriveFullName(t.Elem())
+	md.L0.FullName = aberrantDeriveMessageName(t.Elem(), name)
 	md.L0.ParentFile = filedesc.SurrogateProto2
 	aberrantMessageDescCache[t] = md
 
@@ -191,6 +198,18 @@
 	return md
 }
 
+func aberrantDeriveMessageName(t reflect.Type, name pref.FullName) pref.FullName {
+	if name.IsValid() {
+		return name
+	}
+	if m, ok := reflect.New(t).Interface().(interface{ XXX_MessageName() string }); ok {
+		if name := pref.FullName(m.XXX_MessageName()); name.IsValid() {
+			return name
+		}
+	}
+	return aberrantDeriveFullName(t)
+}
+
 func aberrantAppendField(md *filedesc.Message, goType reflect.Type, tag, tagKey, tagVal string) {
 	t := goType
 	isOptional := t.Kind() == reflect.Ptr && t.Elem().Kind() != reflect.Struct
@@ -260,7 +279,7 @@
 				fd.L1.Message = md2
 				break
 			}
-			fd.L1.Message = aberrantLoadMessageDescReentrant(t)
+			fd.L1.Message = aberrantLoadMessageDescReentrant(t, "")
 		}
 	}
 }
diff --git a/internal/impl/legacy_test.go b/internal/impl/legacy_test.go
index a231fec..6a11f40 100644
--- a/internal/impl/legacy_test.go
+++ b/internal/impl/legacy_test.go
@@ -664,3 +664,36 @@
 		}
 	}
 }
+
+type LegacyTestMessageName1 struct{}
+
+func (*LegacyTestMessageName1) Reset()         { panic("not implemented") }
+func (*LegacyTestMessageName1) String() string { panic("not implemented") }
+func (*LegacyTestMessageName1) ProtoMessage()  { panic("not implemented") }
+
+type LegacyTestMessageName2 struct{}
+
+func (*LegacyTestMessageName2) Reset()         { panic("not implemented") }
+func (*LegacyTestMessageName2) String() string { panic("not implemented") }
+func (*LegacyTestMessageName2) ProtoMessage()  { panic("not implemented") }
+func (*LegacyTestMessageName2) XXX_MessageName() string {
+	return "google.golang.org.LegacyTestMessageName2"
+}
+
+func TestLegacyMessageName(t *testing.T) {
+	tests := []struct {
+		in          piface.MessageV1
+		suggestName pref.FullName
+		wantName    pref.FullName
+	}{
+		{new(LegacyTestMessageName1), "google.golang.org.LegacyTestMessageName1", "google.golang.org.LegacyTestMessageName1"},
+		{new(LegacyTestMessageName2), "", "google.golang.org.LegacyTestMessageName2"},
+	}
+
+	for _, tt := range tests {
+		mt := pimpl.Export{}.LegacyMessageTypeOf(tt.in, tt.suggestName)
+		if got := mt.Descriptor().FullName(); got != tt.wantName {
+			t.Errorf("type: %T, name mismatch: got %v, want %v", tt.in, got, tt.wantName)
+		}
+	}
+}