internal/detectknown: add helper package to identify well-known types

Change-Id: Id54621b4b44522a350e6994074962852690b5d66
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/225257
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/encoding/protojson/well_known_types.go b/encoding/protojson/well_known_types.go
index e5e9b2e..3c3ef14 100644
--- a/encoding/protojson/well_known_types.go
+++ b/encoding/protojson/well_known_types.go
@@ -11,6 +11,7 @@
 	"strings"
 	"time"
 
+	"google.golang.org/protobuf/internal/detectknown"
 	"google.golang.org/protobuf/internal/encoding/json"
 	"google.golang.org/protobuf/internal/errors"
 	"google.golang.org/protobuf/internal/fieldnum"
@@ -23,27 +24,18 @@
 // The list of custom types here has to match the ones in marshalCustomType and
 // unmarshalCustomType.
 func isCustomType(name pref.FullName) bool {
-	switch name {
-	case "google.protobuf.Any",
-		"google.protobuf.BoolValue",
-		"google.protobuf.DoubleValue",
-		"google.protobuf.FloatValue",
-		"google.protobuf.Int32Value",
-		"google.protobuf.Int64Value",
-		"google.protobuf.UInt32Value",
-		"google.protobuf.UInt64Value",
-		"google.protobuf.StringValue",
-		"google.protobuf.BytesValue",
-		"google.protobuf.Empty",
-		"google.protobuf.Struct",
-		"google.protobuf.ListValue",
-		"google.protobuf.Value",
-		"google.protobuf.Duration",
-		"google.protobuf.Timestamp",
-		"google.protobuf.FieldMask":
-		return true
+	switch detectknown.Which(name) {
+	case detectknown.AnyProto:
+	case detectknown.TimestampProto:
+	case detectknown.DurationProto:
+	case detectknown.WrappersProto:
+	case detectknown.StructProto:
+	case detectknown.FieldMaskProto:
+	case detectknown.EmptyProto:
+	default:
+		return false
 	}
-	return false
+	return true
 }
 
 // marshalCustomType marshals given well-known type message that have special
@@ -51,44 +43,24 @@
 // returns true, else it will panic.
 func (e encoder) marshalCustomType(m pref.Message) error {
 	name := m.Descriptor().FullName()
-	switch name {
-	case "google.protobuf.Any":
+	switch detectknown.Which(name) {
+	case detectknown.AnyProto:
 		return e.marshalAny(m)
-
-	case "google.protobuf.BoolValue",
-		"google.protobuf.DoubleValue",
-		"google.protobuf.FloatValue",
-		"google.protobuf.Int32Value",
-		"google.protobuf.Int64Value",
-		"google.protobuf.UInt32Value",
-		"google.protobuf.UInt64Value",
-		"google.protobuf.StringValue",
-		"google.protobuf.BytesValue":
-		return e.marshalWrapperType(m)
-
-	case "google.protobuf.Empty":
-		return e.marshalEmpty(m)
-
-	case "google.protobuf.Struct":
-		return e.marshalStruct(m)
-
-	case "google.protobuf.ListValue":
-		return e.marshalListValue(m)
-
-	case "google.protobuf.Value":
-		return e.marshalKnownValue(m)
-
-	case "google.protobuf.Duration":
-		return e.marshalDuration(m)
-
-	case "google.protobuf.Timestamp":
+	case detectknown.TimestampProto:
 		return e.marshalTimestamp(m)
-
-	case "google.protobuf.FieldMask":
+	case detectknown.DurationProto:
+		return e.marshalDuration(m)
+	case detectknown.WrappersProto:
+		return e.marshalWrapperType(m)
+	case detectknown.StructProto:
+		return e.marshalStructType(m)
+	case detectknown.FieldMaskProto:
 		return e.marshalFieldMask(m)
+	case detectknown.EmptyProto:
+		return e.marshalEmpty(m)
+	default:
+		panic(fmt.Sprintf("%s does not have a custom marshaler", name))
 	}
-
-	panic(fmt.Sprintf("%s does not have a custom marshaler", name))
 }
 
 // unmarshalCustomType unmarshals given well-known type message that have
@@ -96,44 +68,24 @@
 // isCustomType returns true, else it will panic.
 func (d decoder) unmarshalCustomType(m pref.Message) error {
 	name := m.Descriptor().FullName()
-	switch name {
-	case "google.protobuf.Any":
+	switch detectknown.Which(name) {
+	case detectknown.AnyProto:
 		return d.unmarshalAny(m)
-
-	case "google.protobuf.BoolValue",
-		"google.protobuf.DoubleValue",
-		"google.protobuf.FloatValue",
-		"google.protobuf.Int32Value",
-		"google.protobuf.Int64Value",
-		"google.protobuf.UInt32Value",
-		"google.protobuf.UInt64Value",
-		"google.protobuf.StringValue",
-		"google.protobuf.BytesValue":
-		return d.unmarshalWrapperType(m)
-
-	case "google.protobuf.Empty":
-		return d.unmarshalEmpty(m)
-
-	case "google.protobuf.Struct":
-		return d.unmarshalStruct(m)
-
-	case "google.protobuf.ListValue":
-		return d.unmarshalListValue(m)
-
-	case "google.protobuf.Value":
-		return d.unmarshalKnownValue(m)
-
-	case "google.protobuf.Duration":
-		return d.unmarshalDuration(m)
-
-	case "google.protobuf.Timestamp":
+	case detectknown.TimestampProto:
 		return d.unmarshalTimestamp(m)
-
-	case "google.protobuf.FieldMask":
+	case detectknown.DurationProto:
+		return d.unmarshalDuration(m)
+	case detectknown.WrappersProto:
+		return d.unmarshalWrapperType(m)
+	case detectknown.StructProto:
+		return d.unmarshalStructType(m)
+	case detectknown.FieldMaskProto:
 		return d.unmarshalFieldMask(m)
+	case detectknown.EmptyProto:
+		return d.unmarshalEmpty(m)
+	default:
+		panic(fmt.Sprintf("%s does not have a custom unmarshaler", name))
 	}
-
-	panic(fmt.Sprintf("%s does not have a custom unmarshaler", name))
 }
 
 // The JSON representation of an Any message uses the regular representation of
@@ -501,6 +453,32 @@
 	}
 }
 
+func (e encoder) marshalStructType(m pref.Message) error {
+	switch m.Descriptor().Name() {
+	case "Struct":
+		return e.marshalStruct(m)
+	case "ListValue":
+		return e.marshalListValue(m)
+	case "Value":
+		return e.marshalKnownValue(m)
+	default:
+		panic(fmt.Sprintf("invalid struct type: %v", m.Descriptor().FullName()))
+	}
+}
+
+func (d decoder) unmarshalStructType(m pref.Message) error {
+	switch m.Descriptor().Name() {
+	case "Struct":
+		return d.unmarshalStruct(m)
+	case "ListValue":
+		return d.unmarshalListValue(m)
+	case "Value":
+		return d.unmarshalKnownValue(m)
+	default:
+		panic(fmt.Sprintf("invalid struct type: %v", m.Descriptor().FullName()))
+	}
+}
+
 // The JSON representation for Struct is a JSON object that contains the encoded
 // Struct.fields map and follows the serialization rules for a map.
 
diff --git a/internal/detectknown/detect.go b/internal/detectknown/detect.go
new file mode 100644
index 0000000..091c423
--- /dev/null
+++ b/internal/detectknown/detect.go
@@ -0,0 +1,47 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package detectknown provides functionality for detecting well-known types
+// and identifying them by name.
+package detectknown
+
+import "google.golang.org/protobuf/reflect/protoreflect"
+
+type ProtoFile int
+
+const (
+	Unknown ProtoFile = iota
+	AnyProto
+	TimestampProto
+	DurationProto
+	WrappersProto
+	StructProto
+	FieldMaskProto
+	EmptyProto
+)
+
+var wellKnownTypes = map[protoreflect.FullName]ProtoFile{
+	"google.protobuf.Any":         AnyProto,
+	"google.protobuf.Timestamp":   TimestampProto,
+	"google.protobuf.Duration":    DurationProto,
+	"google.protobuf.BoolValue":   WrappersProto,
+	"google.protobuf.Int32Value":  WrappersProto,
+	"google.protobuf.Int64Value":  WrappersProto,
+	"google.protobuf.UInt32Value": WrappersProto,
+	"google.protobuf.UInt64Value": WrappersProto,
+	"google.protobuf.FloatValue":  WrappersProto,
+	"google.protobuf.DoubleValue": WrappersProto,
+	"google.protobuf.BytesValue":  WrappersProto,
+	"google.protobuf.StringValue": WrappersProto,
+	"google.protobuf.Struct":      StructProto,
+	"google.protobuf.ListValue":   StructProto,
+	"google.protobuf.Value":       StructProto,
+	"google.protobuf.FieldMask":   FieldMaskProto,
+	"google.protobuf.Empty":       EmptyProto,
+}
+
+// Which identifies the proto file that a well-known type belongs to.
+func Which(s protoreflect.FullName) ProtoFile {
+	return wellKnownTypes[s]
+}
diff --git a/internal/detectknown/detect_test.go b/internal/detectknown/detect_test.go
new file mode 100644
index 0000000..c9a31c9
--- /dev/null
+++ b/internal/detectknown/detect_test.go
@@ -0,0 +1,58 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package detectknown_test
+
+import (
+	"testing"
+
+	"google.golang.org/protobuf/internal/detectknown"
+	"google.golang.org/protobuf/reflect/protoreflect"
+
+	fieldmaskpb "google.golang.org/protobuf/internal/testprotos/fieldmaskpb"
+	"google.golang.org/protobuf/types/descriptorpb"
+	"google.golang.org/protobuf/types/known/anypb"
+	"google.golang.org/protobuf/types/known/durationpb"
+	"google.golang.org/protobuf/types/known/emptypb"
+	"google.golang.org/protobuf/types/known/structpb"
+	"google.golang.org/protobuf/types/known/timestamppb"
+	"google.golang.org/protobuf/types/known/wrapperspb"
+	"google.golang.org/protobuf/types/pluginpb"
+)
+
+func TestWhich(t *testing.T) {
+	tests := []struct {
+		in   protoreflect.FileDescriptor
+		want detectknown.ProtoFile
+	}{
+		{descriptorpb.File_google_protobuf_descriptor_proto, detectknown.Unknown},
+		{pluginpb.File_google_protobuf_compiler_plugin_proto, detectknown.Unknown},
+		{anypb.File_google_protobuf_any_proto, detectknown.AnyProto},
+		{timestamppb.File_google_protobuf_timestamp_proto, detectknown.TimestampProto},
+		{durationpb.File_google_protobuf_duration_proto, detectknown.DurationProto},
+		{wrapperspb.File_google_protobuf_wrappers_proto, detectknown.WrappersProto},
+		{structpb.File_google_protobuf_struct_proto, detectknown.StructProto},
+		{fieldmaskpb.File_google_protobuf_field_mask_proto, detectknown.FieldMaskProto},
+		{emptypb.File_google_protobuf_empty_proto, detectknown.EmptyProto},
+	}
+
+	for _, tt := range tests {
+		rangeMessages(tt.in.Messages(), func(md protoreflect.MessageDescriptor) {
+			got := detectknown.Which(md.FullName())
+			if got != tt.want {
+				t.Errorf("Which(%s) = %v, want %v", md.FullName(), got, tt.want)
+			}
+		})
+	}
+}
+
+func rangeMessages(mds protoreflect.MessageDescriptors, f func(protoreflect.MessageDescriptor)) {
+	for i := 0; i < mds.Len(); i++ {
+		md := mds.Get(i)
+		if !md.IsMapEntry() {
+			f(md)
+		}
+		rangeMessages(md.Messages(), f)
+	}
+}
diff --git a/internal/msgfmt/format.go b/internal/msgfmt/format.go
index 21023e5..c2c856f 100644
--- a/internal/msgfmt/format.go
+++ b/internal/msgfmt/format.go
@@ -18,6 +18,7 @@
 	"time"
 
 	"google.golang.org/protobuf/encoding/protowire"
+	"google.golang.org/protobuf/internal/detectknown"
 	"google.golang.org/protobuf/internal/detrand"
 	"google.golang.org/protobuf/internal/mapsort"
 	"google.golang.org/protobuf/proto"
@@ -102,13 +103,9 @@
 
 func appendKnownMessage(b []byte, m protoreflect.Message) []byte {
 	md := m.Descriptor()
-	if md.FullName().Parent() != "google.protobuf" {
-		return nil
-	}
-
 	fds := md.Fields()
-	switch md.Name() {
-	case "Any":
+	switch detectknown.Which(md.FullName()) {
+	case detectknown.AnyProto:
 		var msgVal protoreflect.Message
 		url := m.Get(fds.ByName("type_url")).String()
 		if v := reflect.ValueOf(m); v.Type().ConvertibleTo(protocmpMessageType) {
@@ -140,7 +137,7 @@
 		b = append(b, '}')
 		return b
 
-	case "Timestamp":
+	case detectknown.TimestampProto:
 		secs := m.Get(fds.ByName("seconds")).Int()
 		nanos := m.Get(fds.ByName("nanos")).Int()
 		if nanos < 0 || nanos >= 1e9 {
@@ -153,7 +150,7 @@
 		x = strings.TrimSuffix(x, ".000")
 		return append(b, x+"Z"...)
 
-	case "Duration":
+	case detectknown.DurationProto:
 		secs := m.Get(fds.ByName("seconds")).Int()
 		nanos := m.Get(fds.ByName("nanos")).Int()
 		if nanos <= -1e9 || nanos >= 1e9 || (secs > 0 && nanos < 0) || (secs < 0 && nanos > 0) {
@@ -165,7 +162,7 @@
 		x = strings.TrimSuffix(x, ".000")
 		return append(b, x+"s"...)
 
-	case "BoolValue", "Int32Value", "Int64Value", "UInt32Value", "UInt64Value", "FloatValue", "DoubleValue", "StringValue", "BytesValue":
+	case detectknown.WrappersProto:
 		fd := fds.ByName("value")
 		return appendValue(b, m.Get(fd), fd)
 	}