internal/detectknown: add helper package to identify well-known types
Change-Id: Id54621b4b44522a350e6994074962852690b5d66
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/225257
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/encoding/protojson/well_known_types.go b/encoding/protojson/well_known_types.go
index e5e9b2e..3c3ef14 100644
--- a/encoding/protojson/well_known_types.go
+++ b/encoding/protojson/well_known_types.go
@@ -11,6 +11,7 @@
"strings"
"time"
+ "google.golang.org/protobuf/internal/detectknown"
"google.golang.org/protobuf/internal/encoding/json"
"google.golang.org/protobuf/internal/errors"
"google.golang.org/protobuf/internal/fieldnum"
@@ -23,27 +24,18 @@
// The list of custom types here has to match the ones in marshalCustomType and
// unmarshalCustomType.
func isCustomType(name pref.FullName) bool {
- switch name {
- case "google.protobuf.Any",
- "google.protobuf.BoolValue",
- "google.protobuf.DoubleValue",
- "google.protobuf.FloatValue",
- "google.protobuf.Int32Value",
- "google.protobuf.Int64Value",
- "google.protobuf.UInt32Value",
- "google.protobuf.UInt64Value",
- "google.protobuf.StringValue",
- "google.protobuf.BytesValue",
- "google.protobuf.Empty",
- "google.protobuf.Struct",
- "google.protobuf.ListValue",
- "google.protobuf.Value",
- "google.protobuf.Duration",
- "google.protobuf.Timestamp",
- "google.protobuf.FieldMask":
- return true
+ switch detectknown.Which(name) {
+ case detectknown.AnyProto:
+ case detectknown.TimestampProto:
+ case detectknown.DurationProto:
+ case detectknown.WrappersProto:
+ case detectknown.StructProto:
+ case detectknown.FieldMaskProto:
+ case detectknown.EmptyProto:
+ default:
+ return false
}
- return false
+ return true
}
// marshalCustomType marshals given well-known type message that have special
@@ -51,44 +43,24 @@
// returns true, else it will panic.
func (e encoder) marshalCustomType(m pref.Message) error {
name := m.Descriptor().FullName()
- switch name {
- case "google.protobuf.Any":
+ switch detectknown.Which(name) {
+ case detectknown.AnyProto:
return e.marshalAny(m)
-
- case "google.protobuf.BoolValue",
- "google.protobuf.DoubleValue",
- "google.protobuf.FloatValue",
- "google.protobuf.Int32Value",
- "google.protobuf.Int64Value",
- "google.protobuf.UInt32Value",
- "google.protobuf.UInt64Value",
- "google.protobuf.StringValue",
- "google.protobuf.BytesValue":
- return e.marshalWrapperType(m)
-
- case "google.protobuf.Empty":
- return e.marshalEmpty(m)
-
- case "google.protobuf.Struct":
- return e.marshalStruct(m)
-
- case "google.protobuf.ListValue":
- return e.marshalListValue(m)
-
- case "google.protobuf.Value":
- return e.marshalKnownValue(m)
-
- case "google.protobuf.Duration":
- return e.marshalDuration(m)
-
- case "google.protobuf.Timestamp":
+ case detectknown.TimestampProto:
return e.marshalTimestamp(m)
-
- case "google.protobuf.FieldMask":
+ case detectknown.DurationProto:
+ return e.marshalDuration(m)
+ case detectknown.WrappersProto:
+ return e.marshalWrapperType(m)
+ case detectknown.StructProto:
+ return e.marshalStructType(m)
+ case detectknown.FieldMaskProto:
return e.marshalFieldMask(m)
+ case detectknown.EmptyProto:
+ return e.marshalEmpty(m)
+ default:
+ panic(fmt.Sprintf("%s does not have a custom marshaler", name))
}
-
- panic(fmt.Sprintf("%s does not have a custom marshaler", name))
}
// unmarshalCustomType unmarshals given well-known type message that have
@@ -96,44 +68,24 @@
// isCustomType returns true, else it will panic.
func (d decoder) unmarshalCustomType(m pref.Message) error {
name := m.Descriptor().FullName()
- switch name {
- case "google.protobuf.Any":
+ switch detectknown.Which(name) {
+ case detectknown.AnyProto:
return d.unmarshalAny(m)
-
- case "google.protobuf.BoolValue",
- "google.protobuf.DoubleValue",
- "google.protobuf.FloatValue",
- "google.protobuf.Int32Value",
- "google.protobuf.Int64Value",
- "google.protobuf.UInt32Value",
- "google.protobuf.UInt64Value",
- "google.protobuf.StringValue",
- "google.protobuf.BytesValue":
- return d.unmarshalWrapperType(m)
-
- case "google.protobuf.Empty":
- return d.unmarshalEmpty(m)
-
- case "google.protobuf.Struct":
- return d.unmarshalStruct(m)
-
- case "google.protobuf.ListValue":
- return d.unmarshalListValue(m)
-
- case "google.protobuf.Value":
- return d.unmarshalKnownValue(m)
-
- case "google.protobuf.Duration":
- return d.unmarshalDuration(m)
-
- case "google.protobuf.Timestamp":
+ case detectknown.TimestampProto:
return d.unmarshalTimestamp(m)
-
- case "google.protobuf.FieldMask":
+ case detectknown.DurationProto:
+ return d.unmarshalDuration(m)
+ case detectknown.WrappersProto:
+ return d.unmarshalWrapperType(m)
+ case detectknown.StructProto:
+ return d.unmarshalStructType(m)
+ case detectknown.FieldMaskProto:
return d.unmarshalFieldMask(m)
+ case detectknown.EmptyProto:
+ return d.unmarshalEmpty(m)
+ default:
+ panic(fmt.Sprintf("%s does not have a custom unmarshaler", name))
}
-
- panic(fmt.Sprintf("%s does not have a custom unmarshaler", name))
}
// The JSON representation of an Any message uses the regular representation of
@@ -501,6 +453,32 @@
}
}
+func (e encoder) marshalStructType(m pref.Message) error {
+ switch m.Descriptor().Name() {
+ case "Struct":
+ return e.marshalStruct(m)
+ case "ListValue":
+ return e.marshalListValue(m)
+ case "Value":
+ return e.marshalKnownValue(m)
+ default:
+ panic(fmt.Sprintf("invalid struct type: %v", m.Descriptor().FullName()))
+ }
+}
+
+func (d decoder) unmarshalStructType(m pref.Message) error {
+ switch m.Descriptor().Name() {
+ case "Struct":
+ return d.unmarshalStruct(m)
+ case "ListValue":
+ return d.unmarshalListValue(m)
+ case "Value":
+ return d.unmarshalKnownValue(m)
+ default:
+ panic(fmt.Sprintf("invalid struct type: %v", m.Descriptor().FullName()))
+ }
+}
+
// The JSON representation for Struct is a JSON object that contains the encoded
// Struct.fields map and follows the serialization rules for a map.
diff --git a/internal/detectknown/detect.go b/internal/detectknown/detect.go
new file mode 100644
index 0000000..091c423
--- /dev/null
+++ b/internal/detectknown/detect.go
@@ -0,0 +1,47 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package detectknown provides functionality for detecting well-known types
+// and identifying them by name.
+package detectknown
+
+import "google.golang.org/protobuf/reflect/protoreflect"
+
+type ProtoFile int
+
+const (
+ Unknown ProtoFile = iota
+ AnyProto
+ TimestampProto
+ DurationProto
+ WrappersProto
+ StructProto
+ FieldMaskProto
+ EmptyProto
+)
+
+var wellKnownTypes = map[protoreflect.FullName]ProtoFile{
+ "google.protobuf.Any": AnyProto,
+ "google.protobuf.Timestamp": TimestampProto,
+ "google.protobuf.Duration": DurationProto,
+ "google.protobuf.BoolValue": WrappersProto,
+ "google.protobuf.Int32Value": WrappersProto,
+ "google.protobuf.Int64Value": WrappersProto,
+ "google.protobuf.UInt32Value": WrappersProto,
+ "google.protobuf.UInt64Value": WrappersProto,
+ "google.protobuf.FloatValue": WrappersProto,
+ "google.protobuf.DoubleValue": WrappersProto,
+ "google.protobuf.BytesValue": WrappersProto,
+ "google.protobuf.StringValue": WrappersProto,
+ "google.protobuf.Struct": StructProto,
+ "google.protobuf.ListValue": StructProto,
+ "google.protobuf.Value": StructProto,
+ "google.protobuf.FieldMask": FieldMaskProto,
+ "google.protobuf.Empty": EmptyProto,
+}
+
+// Which identifies the proto file that a well-known type belongs to.
+func Which(s protoreflect.FullName) ProtoFile {
+ return wellKnownTypes[s]
+}
diff --git a/internal/detectknown/detect_test.go b/internal/detectknown/detect_test.go
new file mode 100644
index 0000000..c9a31c9
--- /dev/null
+++ b/internal/detectknown/detect_test.go
@@ -0,0 +1,58 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package detectknown_test
+
+import (
+ "testing"
+
+ "google.golang.org/protobuf/internal/detectknown"
+ "google.golang.org/protobuf/reflect/protoreflect"
+
+ fieldmaskpb "google.golang.org/protobuf/internal/testprotos/fieldmaskpb"
+ "google.golang.org/protobuf/types/descriptorpb"
+ "google.golang.org/protobuf/types/known/anypb"
+ "google.golang.org/protobuf/types/known/durationpb"
+ "google.golang.org/protobuf/types/known/emptypb"
+ "google.golang.org/protobuf/types/known/structpb"
+ "google.golang.org/protobuf/types/known/timestamppb"
+ "google.golang.org/protobuf/types/known/wrapperspb"
+ "google.golang.org/protobuf/types/pluginpb"
+)
+
+func TestWhich(t *testing.T) {
+ tests := []struct {
+ in protoreflect.FileDescriptor
+ want detectknown.ProtoFile
+ }{
+ {descriptorpb.File_google_protobuf_descriptor_proto, detectknown.Unknown},
+ {pluginpb.File_google_protobuf_compiler_plugin_proto, detectknown.Unknown},
+ {anypb.File_google_protobuf_any_proto, detectknown.AnyProto},
+ {timestamppb.File_google_protobuf_timestamp_proto, detectknown.TimestampProto},
+ {durationpb.File_google_protobuf_duration_proto, detectknown.DurationProto},
+ {wrapperspb.File_google_protobuf_wrappers_proto, detectknown.WrappersProto},
+ {structpb.File_google_protobuf_struct_proto, detectknown.StructProto},
+ {fieldmaskpb.File_google_protobuf_field_mask_proto, detectknown.FieldMaskProto},
+ {emptypb.File_google_protobuf_empty_proto, detectknown.EmptyProto},
+ }
+
+ for _, tt := range tests {
+ rangeMessages(tt.in.Messages(), func(md protoreflect.MessageDescriptor) {
+ got := detectknown.Which(md.FullName())
+ if got != tt.want {
+ t.Errorf("Which(%s) = %v, want %v", md.FullName(), got, tt.want)
+ }
+ })
+ }
+}
+
+func rangeMessages(mds protoreflect.MessageDescriptors, f func(protoreflect.MessageDescriptor)) {
+ for i := 0; i < mds.Len(); i++ {
+ md := mds.Get(i)
+ if !md.IsMapEntry() {
+ f(md)
+ }
+ rangeMessages(md.Messages(), f)
+ }
+}
diff --git a/internal/msgfmt/format.go b/internal/msgfmt/format.go
index 21023e5..c2c856f 100644
--- a/internal/msgfmt/format.go
+++ b/internal/msgfmt/format.go
@@ -18,6 +18,7 @@
"time"
"google.golang.org/protobuf/encoding/protowire"
+ "google.golang.org/protobuf/internal/detectknown"
"google.golang.org/protobuf/internal/detrand"
"google.golang.org/protobuf/internal/mapsort"
"google.golang.org/protobuf/proto"
@@ -102,13 +103,9 @@
func appendKnownMessage(b []byte, m protoreflect.Message) []byte {
md := m.Descriptor()
- if md.FullName().Parent() != "google.protobuf" {
- return nil
- }
-
fds := md.Fields()
- switch md.Name() {
- case "Any":
+ switch detectknown.Which(md.FullName()) {
+ case detectknown.AnyProto:
var msgVal protoreflect.Message
url := m.Get(fds.ByName("type_url")).String()
if v := reflect.ValueOf(m); v.Type().ConvertibleTo(protocmpMessageType) {
@@ -140,7 +137,7 @@
b = append(b, '}')
return b
- case "Timestamp":
+ case detectknown.TimestampProto:
secs := m.Get(fds.ByName("seconds")).Int()
nanos := m.Get(fds.ByName("nanos")).Int()
if nanos < 0 || nanos >= 1e9 {
@@ -153,7 +150,7 @@
x = strings.TrimSuffix(x, ".000")
return append(b, x+"Z"...)
- case "Duration":
+ case detectknown.DurationProto:
secs := m.Get(fds.ByName("seconds")).Int()
nanos := m.Get(fds.ByName("nanos")).Int()
if nanos <= -1e9 || nanos >= 1e9 || (secs > 0 && nanos < 0) || (secs < 0 && nanos > 0) {
@@ -165,7 +162,7 @@
x = strings.TrimSuffix(x, ".000")
return append(b, x+"s"...)
- case "BoolValue", "Int32Value", "Int64Value", "UInt32Value", "UInt64Value", "FloatValue", "DoubleValue", "StringValue", "BytesValue":
+ case detectknown.WrappersProto:
fd := fds.ByName("value")
return appendValue(b, m.Get(fd), fd)
}