testing/protocmp: add Message.Unwrap

The Unwrap method returns the original concrete message value.
In theory this allows users to mutate the original message when the
cmp documentation says that all options should be mutation free.
If users want to disregard this documented restriction, they can
already do so in a number of different ways.

Updates #1347

Change-Id: I65225681ab5dbce0763a140fd02666a4ab542a04
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/340489
Trust: Joe Tsai <joetsai@digital-static.net>
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/testing/protocmp/reflect.go b/testing/protocmp/reflect.go
index 5b92cb8..0a5e474 100644
--- a/testing/protocmp/reflect.go
+++ b/testing/protocmp/reflect.go
@@ -68,7 +68,7 @@
 	}
 
 	// Range over populated extension fields.
-	for _, xd := range m[messageTypeKey].(messageType).xds {
+	for _, xd := range m[messageTypeKey].(messageMeta).xds {
 		if m.Has(xd) && !f(xd, m.Get(xd)) {
 			return
 		}
@@ -91,7 +91,7 @@
 			return protoreflect.ValueOfMap(reflectMap{})
 		case fd.Message() != nil:
 			return protoreflect.ValueOfMessage(reflectMessage{
-				messageTypeKey: messageType{md: m.Descriptor()},
+				messageTypeKey: messageMeta{md: fd.Message()},
 			})
 		default:
 			return fd.Default()
diff --git a/testing/protocmp/util.go b/testing/protocmp/util.go
index ac6237e..79f3072 100644
--- a/testing/protocmp/util.go
+++ b/testing/protocmp/util.go
@@ -297,11 +297,11 @@
 		return true // treat missing fields as already filtered
 	}
 	var fd protoreflect.FieldDescriptor
-	switch mt := m[messageTypeKey].(messageType); {
+	switch mm := m[messageTypeKey].(messageMeta); {
 	case protoreflect.Name(k).IsValid():
-		fd = mt.md.Fields().ByTextName(k)
+		fd = mm.md.Fields().ByTextName(k)
 	default:
-		fd = mt.xds[k]
+		fd = mm.xds[k]
 	}
 	if fd != nil {
 		return f.names[fd.FullName()]
@@ -376,11 +376,11 @@
 	}
 
 	var fd protoreflect.FieldDescriptor
-	switch mt := m[messageTypeKey].(messageType); {
+	switch mm := m[messageTypeKey].(messageMeta); {
 	case protoreflect.Name(k).IsValid():
-		fd = mt.md.Fields().ByTextName(k)
+		fd = mm.md.Fields().ByTextName(k)
 	default:
-		fd = mt.xds[k]
+		fd = mm.xds[k]
 	}
 	if fd == nil || !fd.Default().IsValid() {
 		return false
diff --git a/testing/protocmp/xform.go b/testing/protocmp/xform.go
index 5a47d0f..7a32e2d 100644
--- a/testing/protocmp/xform.go
+++ b/testing/protocmp/xform.go
@@ -68,20 +68,28 @@
 }
 
 const (
-	messageTypeKey    = "@type"
+	// messageTypeKey indicates the protobuf message type.
+	// The value type is always messageMeta.
+	// From the public API, it presents itself as only the type, but the
+	// underlying data structure holds arbitrary metadata about the message.
+	messageTypeKey = "@type"
+
+	// messageInvalidKey indicates that the message is invalid.
+	// The value is always the boolean "true".
 	messageInvalidKey = "@invalid"
 )
 
-type messageType struct {
+type messageMeta struct {
+	m   proto.Message
 	md  protoreflect.MessageDescriptor
 	xds map[string]protoreflect.ExtensionDescriptor
 }
 
-func (t messageType) String() string {
+func (t messageMeta) String() string {
 	return string(t.md.FullName())
 }
 
-func (t1 messageType) Equal(t2 messageType) bool {
+func (t1 messageMeta) Equal(t2 messageMeta) bool {
 	return t1.md.FullName() == t2.md.FullName()
 }
 
@@ -109,11 +117,18 @@
 // Message values must not be created by or mutated by users.
 type Message map[string]interface{}
 
+// Unwrap returns the original message value.
+// It returns nil if this Message was not constructed from another message.
+func (m Message) Unwrap() proto.Message {
+	mm, _ := m[messageTypeKey].(messageMeta)
+	return mm.m
+}
+
 // Descriptor return the message descriptor.
 // It returns nil for a zero Message value.
 func (m Message) Descriptor() protoreflect.MessageDescriptor {
-	mt, _ := m[messageTypeKey].(messageType)
-	return mt.md
+	mm, _ := m[messageTypeKey].(messageMeta)
+	return mm.md
 }
 
 // ProtoReflect returns a reflective view of m.
@@ -201,7 +216,7 @@
 		case m == nil:
 			return nil
 		case !m.IsValid():
-			return Message{messageTypeKey: messageType{md: m.Descriptor()}, messageInvalidKey: true}
+			return Message{messageTypeKey: messageMeta{m: m.Interface(), md: m.Descriptor()}, messageInvalidKey: true}
 		default:
 			return transformMessage(m)
 		}
@@ -218,7 +233,7 @@
 
 func transformMessage(m protoreflect.Message) Message {
 	mx := Message{}
-	mt := messageType{md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)}
+	mt := messageMeta{m: m.Interface(), md: m.Descriptor(), xds: make(map[string]protoreflect.FieldDescriptor)}
 
 	// Handle known and extension fields.
 	m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
diff --git a/testing/protocmp/xform_test.go b/testing/protocmp/xform_test.go
index c6355d7..a0f7e24 100644
--- a/testing/protocmp/xform_test.go
+++ b/testing/protocmp/xform_test.go
@@ -40,7 +40,7 @@
 			OptionalNestedMessage: &testpb.TestAllTypes_NestedMessage{A: proto.Int32(5)},
 		},
 		want: Message{
-			messageTypeKey:            messageTypeOf(&testpb.TestAllTypes{}),
+			messageTypeKey:            messageMetaOf(&testpb.TestAllTypes{}),
 			"optional_bool":           bool(false),
 			"optional_int32":          int32(-32),
 			"optional_int64":          int64(-64),
@@ -51,7 +51,7 @@
 			"optional_string":         string("string"),
 			"optional_bytes":          []byte("bytes"),
 			"optional_nested_enum":    enumOf(testpb.TestAllTypes_NEG),
-			"optional_nested_message": Message{messageTypeKey: messageTypeOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
+			"optional_nested_message": Message{messageTypeKey: messageMetaOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
 		},
 	}, {
 		in: &testpb.TestAllTypes{
@@ -74,7 +74,7 @@
 			},
 		},
 		want: Message{
-			messageTypeKey:    messageTypeOf(&testpb.TestAllTypes{}),
+			messageTypeKey:    messageMetaOf(&testpb.TestAllTypes{}),
 			"repeated_bool":   []bool{false, true},
 			"repeated_int32":  []int32{32, -32},
 			"repeated_int64":  []int64{64, -64},
@@ -89,8 +89,8 @@
 				enumOf(testpb.TestAllTypes_BAR),
 			},
 			"repeated_nested_message": []Message{
-				{messageTypeKey: messageTypeOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
-				{messageTypeKey: messageTypeOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(-5)},
+				{messageTypeKey: messageMetaOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
+				{messageTypeKey: messageMetaOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(-5)},
 			},
 		},
 	}, {
@@ -112,7 +112,7 @@
 			},
 		},
 		want: Message{
-			messageTypeKey:      messageTypeOf(&testpb.TestAllTypes{}),
+			messageTypeKey:      messageMetaOf(&testpb.TestAllTypes{}),
 			"map_bool_bool":     map[bool]bool{true: false},
 			"map_int32_int32":   map[int32]int32{-32: 32},
 			"map_int64_int64":   map[int64]int64{-64: 64},
@@ -126,7 +126,7 @@
 				"k": enumOf(testpb.TestAllTypes_FOO),
 			},
 			"map_string_nested_message": map[string]Message{
-				"k": {messageTypeKey: messageTypeOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
+				"k": {messageTypeKey: messageMetaOf(&testpb.TestAllTypes_NestedMessage{}), "a": int32(5)},
 			},
 		},
 	}, {
@@ -146,7 +146,7 @@
 			return m
 		}(),
 		want: Message{
-			messageTypeKey:                                 messageTypeOf(&testpb.TestAllExtensions{}),
+			messageTypeKey:                                 messageMetaOf(&testpb.TestAllExtensions{}),
 			"[goproto.proto.test.optional_bool]":           bool(false),
 			"[goproto.proto.test.optional_int32]":          int32(-32),
 			"[goproto.proto.test.optional_int64]":          int64(-64),
@@ -157,7 +157,7 @@
 			"[goproto.proto.test.optional_string]":         string("string"),
 			"[goproto.proto.test.optional_bytes]":          []byte("bytes"),
 			"[goproto.proto.test.optional_nested_enum]":    enumOf(testpb.TestAllTypes_NEG),
-			"[goproto.proto.test.optional_nested_message]": Message{messageTypeKey: messageTypeOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(5)},
+			"[goproto.proto.test.optional_nested_message]": Message{messageTypeKey: messageMetaOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(5)},
 		},
 	}, {
 		in: func() proto.Message {
@@ -182,7 +182,7 @@
 			return m
 		}(),
 		want: Message{
-			messageTypeKey:                         messageTypeOf(&testpb.TestAllExtensions{}),
+			messageTypeKey:                         messageMetaOf(&testpb.TestAllExtensions{}),
 			"[goproto.proto.test.repeated_bool]":   []bool{false, true},
 			"[goproto.proto.test.repeated_int32]":  []int32{32, -32},
 			"[goproto.proto.test.repeated_int64]":  []int64{64, -64},
@@ -197,8 +197,8 @@
 				enumOf(testpb.TestAllTypes_BAR),
 			},
 			"[goproto.proto.test.repeated_nested_message]": []Message{
-				{messageTypeKey: messageTypeOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(5)},
-				{messageTypeKey: messageTypeOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(-5)},
+				{messageTypeKey: messageMetaOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(5)},
+				{messageTypeKey: messageMetaOf(&testpb.TestAllExtensions_NestedMessage{}), "a": int32(-5)},
 			},
 		},
 	}, {
@@ -229,7 +229,7 @@
 			return m
 		}(),
 		want: Message{
-			messageTypeKey: messageTypeOf(&testpb.TestAllTypes{}),
+			messageTypeKey: messageMetaOf(&testpb.TestAllTypes{}),
 			"50000":        protoreflect.RawFields(protopack.Message{protopack.Tag{Number: 50000, Type: protopack.VarintType}, protopack.Uvarint(100)}.Marshal()),
 			"50001":        protoreflect.RawFields(protopack.Message{protopack.Tag{Number: 50001, Type: protopack.Fixed32Type}, protopack.Uint32(200)}.Marshal()),
 			"50002":        protoreflect.RawFields(protopack.Message{protopack.Tag{Number: 50002, Type: protopack.Fixed64Type}, protopack.Uint64(300)}.Marshal()),
@@ -258,6 +258,9 @@
 			if diff := cmp.Diff(tt.want, got); diff != "" {
 				t.Errorf("Transform() mismatch (-want +got):\n%v", diff)
 			}
+			if got.Unwrap() != tt.in {
+				t.Errorf("got.Unwrap() = %p, want %p", got.Unwrap(), tt.in)
+			}
 		})
 	}
 }
@@ -266,6 +269,6 @@
 	return Enum{e.Number(), e.Descriptor()}
 }
 
-func messageTypeOf(m protoreflect.ProtoMessage) messageType {
-	return messageType{md: m.ProtoReflect().Descriptor()}
+func messageMetaOf(m protoreflect.ProtoMessage) messageMeta {
+	return messageMeta{m: m, md: m.ProtoReflect().Descriptor()}
 }