blob: 7a0b49b18eb706a55942b7eabf02702b4f1604c1 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package protocmp
import (
"bytes"
"fmt"
"math"
"reflect"
"strings"
"github.com/google/go-cmp/cmp"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)
// Reflection types for this package's Enum and Message types. They are
// compared against the types of values found while walking a cmp.Path to
// recognize enum and message values (and lists/maps of them).
var (
	enumReflectType    = reflect.TypeOf((*Enum)(nil)).Elem()
	messageReflectType = reflect.TypeOf((*Message)(nil)).Elem()
)
// IgnoreEnums ignores all enums of the specified types.
// Each enum contributes its EnumDescriptor; see IgnoreDescriptors with regard
// to EnumDescriptors for how they are matched.
//
// This must be used in conjunction with Transform.
func IgnoreEnums(enums ...protoreflect.Enum) cmp.Option {
	descs := make([]protoreflect.Descriptor, 0, len(enums))
	for i := range enums {
		descs = append(descs, enums[i].Descriptor())
	}
	return IgnoreDescriptors(descs...)
}
// IgnoreMessages ignores all messages of the specified types.
// Each message contributes its MessageDescriptor; see IgnoreDescriptors with
// regard to MessageDescriptors for how they are matched.
//
// This must be used in conjunction with Transform.
func IgnoreMessages(messages ...proto.Message) cmp.Option {
	descs := make([]protoreflect.Descriptor, 0, len(messages))
	for i := range messages {
		descs = append(descs, messages[i].ProtoReflect().Descriptor())
	}
	return IgnoreDescriptors(descs...)
}
// IgnoreFields ignores the named fields in messages of the same type as message.
// It panics if message has no field with one of the given names.
// See IgnoreDescriptors with regard to FieldDescriptors for more information.
//
// This must be used in conjunction with Transform.
func IgnoreFields(message proto.Message, names ...protoreflect.Name) cmp.Option {
	md := message.ProtoReflect().Descriptor()
	descs := make([]protoreflect.Descriptor, len(names))
	for i, name := range names {
		descs[i] = mustFindFieldDescriptor(md, name)
	}
	return IgnoreDescriptors(descs...)
}
// IgnoreOneofs ignores the fields of the named oneofs in messages of the same
// type as message.
// It panics if message has no oneof with one of the given names.
// See IgnoreDescriptors with regard to OneofDescriptors for more information.
//
// This must be used in conjunction with Transform.
func IgnoreOneofs(message proto.Message, names ...protoreflect.Name) cmp.Option {
	md := message.ProtoReflect().Descriptor()
	descs := make([]protoreflect.Descriptor, len(names))
	for i, name := range names {
		descs[i] = mustFindOneofDescriptor(md, name)
	}
	return IgnoreDescriptors(descs...)
}
// IgnoreDescriptors ignores the specified set of descriptors.
// The following descriptor types may be specified:
//
// • EnumDescriptor: Enums of this type or messages containing singular fields,
// list fields, or map fields with enum values of this type are ignored.
// Enums are matched based on their full name.
//
// • MessageDescriptor: Messages of this type or messages containing
// singular fields, list fields, or map fields with message values of this type
// are ignored. Messages are matched based on their full name.
//
// • ExtensionDescriptor: Extension fields that match the given descriptor
// by full name are ignored.
//
// • FieldDescriptor: Message fields that match the given descriptor
// by full name are ignored.
//
// • OneofDescriptor: Message fields that match the set of fields in the given
// oneof descriptor by full name are ignored.
//
// Any other kind of descriptor causes a panic.
//
// This must be used in conjunction with Transform.
func IgnoreDescriptors(descs ...protoreflect.Descriptor) cmp.Option {
	filters := newNameFilters(descs...)
	return cmp.FilterPath(filters.Filter, cmp.Ignore())
}
// mustFindFieldDescriptor returns the field of md whose name is exactly s.
// Otherwise it panics. When findDescriptor locates a near-miss field or oneof,
// the panic message suggests that name instead.
func mustFindFieldDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.FieldDescriptor {
	found := findDescriptor(md, s)
	fd, isField := found.(protoreflect.FieldDescriptor)
	if isField && fd.Name() == s {
		return fd
	}

	hint := ""
	if isField {
		hint = fmt.Sprintf("; consider specifying field %q instead", fd.Name())
	} else if od, isOneof := found.(protoreflect.OneofDescriptor); isOneof {
		hint = fmt.Sprintf("; consider specifying oneof %q with IgnoreOneofs instead", od.Name())
	}
	panic(fmt.Sprintf("message %q has no field %q%s", md.FullName(), s, hint))
}
// mustFindOneofDescriptor returns the oneof of md whose name is exactly s.
// Otherwise it panics. When findDescriptor locates a near-miss oneof or field,
// the panic message suggests that name instead.
func mustFindOneofDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.OneofDescriptor {
	found := findDescriptor(md, s)
	od, isOneof := found.(protoreflect.OneofDescriptor)
	if isOneof && od.Name() == s {
		return od
	}

	hint := ""
	if isOneof {
		hint = fmt.Sprintf("; consider specifying oneof %q instead", od.Name())
	} else if fd, isField := found.(protoreflect.FieldDescriptor); isField {
		hint = fmt.Sprintf("; consider specifying field %q with IgnoreFields instead", fd.Name())
	}
	panic(fmt.Sprintf("message %q has no oneof %q%s", md.FullName(), s, hint))
}
// findDescriptor returns the field or oneof of md named s, or nil if none
// matches.
//
// An exact name match is tried first (fields before oneofs). Failing that, a
// best-effort match compares names case-insensitively with underscores
// removed, so callers can suggest the name the user probably meant.
func findDescriptor(md protoreflect.MessageDescriptor, s protoreflect.Name) protoreflect.Descriptor {
	// Exact match.
	if fd := md.Fields().ByName(s); fd != nil {
		return fd
	}
	if od := md.Oneofs().ByName(s); od != nil {
		return od
	}

	// Best-effort match.
	//
	// It's a common user mistake to use the CamelCased field name as it appears
	// in the generated Go struct. Instead of complaining that it doesn't exist,
	// suggest the real protobuf name that the user may have desired.
	normalize := func(s protoreflect.Name) string {
		return strings.Replace(strings.ToLower(string(s)), "_", "", -1)
	}
	want := normalize(s) // loop-invariant: compute once rather than per candidate
	fields := md.Fields()
	for i := 0; i < fields.Len(); i++ {
		if fd := fields.Get(i); normalize(fd.Name()) == want {
			return fd
		}
	}
	oneofs := md.Oneofs()
	for i := 0; i < oneofs.Len(); i++ {
		if od := oneofs.Get(i); normalize(od.Name()) == want {
			return od
		}
	}
	return nil
}
// nameFilters holds the set of descriptor full names whose matching values
// (enums, messages) or message fields should be ignored during comparison.
type nameFilters struct {
	// names is the set of full names to ignore; present keys map to true.
	names map[protoreflect.FullName]bool
}
// newNameFilters builds a nameFilters from descs. Enum, message, and field
// (including extension) descriptors contribute their own full name; a oneof
// contributes the full names of all of its fields. Any other descriptor
// type, including a nil descriptor, causes a panic.
func newNameFilters(descs ...protoreflect.Descriptor) *nameFilters {
	names := make(map[protoreflect.FullName]bool)
	for _, desc := range descs {
		switch d := desc.(type) {
		case protoreflect.EnumDescriptor, protoreflect.MessageDescriptor, protoreflect.FieldDescriptor:
			names[d.FullName()] = true
		case protoreflect.OneofDescriptor:
			fields := d.Fields()
			for i := 0; i < fields.Len(); i++ {
				names[fields.Get(i).FullName()] = true
			}
		default:
			panic("invalid descriptor type")
		}
	}
	return &nameFilters{names: names}
}
// Filter is a cmp.FilterPath predicate. It reports true when both values at
// the last step of p are filtered by filterValue, or when filterFields
// matches the step as a filtered Message map entry.
func (f *nameFilters) Filter(p cmp.Path) bool {
	vx, vy := p.Last().Values()
	if f.filterValue(vx) && f.filterValue(vy) {
		return true
	}
	return f.filterFields(p)
}
// filterFields reports whether the last step of p is a map index into a
// Message whose entry should be ignored. The entry is ignored if its key
// names a filtered field in both messages, or if its value on both sides is
// of a filtered enum/message type.
func (f *nameFilters) filterFields(p cmp.Path) bool {
	// Only map-index steps directly under a Message are candidates.
	idx, isMapIndex := p.Index(-1).(cmp.MapIndex)
	if !isMapIndex {
		return false
	}
	parent := p.Index(-2)
	if parent.Type() != messageReflectType {
		return false
	}

	// Filter by field name.
	px, py := parent.Values()
	mx, my := px.Interface().(Message), py.Interface().(Message)
	key := idx.Key().String()
	if f.filterFieldName(mx, key) && f.filterFieldName(my, key) {
		return true
	}

	// Filter by field value type.
	ex, ey := idx.Values()
	return f.filterFieldValue(ex) && f.filterFieldValue(ey)
}
// filterFieldName reports whether the map key k of message m refers to a
// filtered field. A plain identifier key is resolved as a field of m's
// descriptor. A bracketed key "[full.name]" is treated as an extension's full
// name. Any other key, or an identifier that is not a field of m's type,
// reports false.
func (f *nameFilters) filterFieldName(m Message, k string) bool {
	md := m.Descriptor()
	if md == nil {
		return false
	}
	switch {
	case protoreflect.Name(k).IsValid():
		// When two Message maps of different types are compared, k may come
		// from the other message and not exist here. ByName then returns nil,
		// which must not be dereferenced.
		if fd := md.Fields().ByName(protoreflect.Name(k)); fd != nil {
			return f.names[fd.FullName()]
		}
	case strings.HasPrefix(k, "[") && strings.HasSuffix(k, "]"):
		return f.names[protoreflect.FullName(k[1:len(k)-1])]
	}
	return false
}
// filterFieldValue reports whether the Message map entry v holds a filtered
// enum or message value. The value may be a singular field, the elements of a
// list field, or the values of a map field. An invalid v (the entry is absent
// on this side) reports true. An empty list or map reports false, because it
// has no element whose type can be inspected.
func (f *nameFilters) filterFieldValue(v reflect.Value) bool {
	if !v.IsValid() {
		return true // implies missing slice element or map entry
	}
	v = v.Elem() // map entries are always populated values
	switch t := v.Type(); {
	case t == enumReflectType || t == messageReflectType:
		// Check for singular message or enum field.
		return f.filterValue(v)
	case t.Kind() == reflect.Slice && (t.Elem() == enumReflectType || t.Elem() == messageReflectType):
		// Check for list field of enum or message type. Only the first
		// element is inspected; this assumes all elements share one type.
		// Guard against an empty list, which would make Index(0) panic.
		if v.Len() == 0 {
			return false
		}
		return f.filterValue(v.Index(0))
	case t.Kind() == reflect.Map && (t.Elem() == enumReflectType || t.Elem() == messageReflectType):
		// Check for map field of enum or message type (first value only).
		// Guard against an empty map, which would make keys[0] panic.
		keys := v.MapKeys()
		if len(keys) == 0 {
			return false
		}
		return f.filterValue(v.MapIndex(keys[0]))
	}
	return false
}
// filterValue reports whether v is an Enum or Message whose descriptor full
// name is in the filter set. An invalid v reports true. A non-interfaceable v
// (for example, an unexported struct field) reports false. A value of any
// other type, or one with a nil descriptor, also reports false.
func (f *nameFilters) filterValue(v reflect.Value) bool {
	switch {
	case !v.IsValid():
		return true // implies missing slice element or map entry
	case !v.CanInterface():
		return false // implies unexported struct field
	}

	var desc protoreflect.Descriptor
	switch x := v.Interface().(type) {
	case Enum:
		desc = x.Descriptor()
	case Message:
		desc = x.Descriptor()
	default:
		return false
	}
	return desc != nil && f.names[desc.FullName()]
}
// IgnoreDefaultScalars ignores singular scalars that are unpopulated or
// explicitly set to the default value.
// This option does not affect elements in a list or entries in a map.
//
// This must be used in conjunction with Transform.
func IgnoreDefaultScalars() cmp.Option {
	filter := func(p cmp.Path) bool {
		// Only map-index steps directly under a Message are candidates.
		idx, isMapIndex := p.Index(-1).(cmp.MapIndex)
		if !isMapIndex {
			return false
		}
		parent := p.Index(-2)
		if parent.Type() != messageReflectType {
			return false
		}

		// Ignore the field only when it is unpopulated or default on both sides.
		px, py := parent.Values()
		mx, my := px.Interface().(Message), py.Interface().(Message)
		key := idx.Key().String()
		return isDefaultScalar(mx, key) && isDefaultScalar(my, key)
	}
	return cmp.FilterPath(filter, cmp.Ignore())
}
// isDefaultScalar reports whether entry k of m is absent, or holds a value
// equal to its field's default. Plain identifier keys are resolved as regular
// fields; bracketed keys "[full.name]" are resolved through the message's
// extension descriptors. It reports false when the field cannot be resolved
// or has no valid default. Float and double defaults treat NaN as equal to NaN.
func isDefaultScalar(m Message, k string) bool {
	val, present := m[k]
	if !present {
		return true
	}

	mt := m[messageTypeKey].(messageType)
	var fd protoreflect.FieldDescriptor
	if protoreflect.Name(k).IsValid() {
		fd = mt.md.Fields().ByName(protoreflect.Name(k))
	} else if strings.HasPrefix(k, "[") && strings.HasSuffix(k, "]") {
		fd = mt.xds[protoreflect.FullName(k[1:len(k)-1])]
	}
	if fd == nil {
		return false
	}
	def := fd.Default()
	if !def.IsValid() {
		return false
	}

	switch fd.Kind() {
	case protoreflect.BytesKind:
		b, ok := val.([]byte)
		return ok && bytes.Equal(def.Bytes(), b)
	case protoreflect.FloatKind:
		f32, ok := val.(float32)
		return ok && equalFloat64(def.Float(), float64(f32))
	case protoreflect.DoubleKind:
		f64, ok := val.(float64)
		return ok && equalFloat64(def.Float(), f64)
	case protoreflect.EnumKind:
		e, ok := val.(Enum)
		return ok && def.Enum() == e.Number()
	default:
		return reflect.DeepEqual(def.Interface(), val)
	}
}
// equalFloat64 reports whether x and y are equal. Unlike ==, it treats any
// two NaN values as equal to each other.
func equalFloat64(x, y float64) bool {
	if math.IsNaN(x) {
		return math.IsNaN(y)
	}
	return x == y
}
// IgnoreEmptyMessages ignores messages that are empty or unpopulated.
// It applies to standalone Messages, singular message fields,
// list fields of messages, and map fields of message values.
//
// This must be used in conjunction with Transform.
func IgnoreEmptyMessages() cmp.Option {
	filter := func(p cmp.Path) bool {
		vx, vy := p.Last().Values()
		if isEmptyMessage(vx) && isEmptyMessage(vy) {
			return true
		}
		return isEmptyMessageFields(p)
	}
	return cmp.FilterPath(filter, cmp.Ignore())
}
// isEmptyMessageFields reports whether the last step of p is a map index
// into a Message whose entry, on both sides, is absent or holds only empty
// messages.
func isEmptyMessageFields(p cmp.Path) bool {
	idx, isMapIndex := p.Index(-1).(cmp.MapIndex)
	if !isMapIndex || p.Index(-2).Type() != messageReflectType {
		return false
	}
	ex, ey := idx.Values()
	return isEmptyMessageFieldValue(ex) && isEmptyMessageFieldValue(ey)
}
// isEmptyMessageFieldValue reports whether the Message map entry v is absent
// (invalid), or is a message field, list of messages, or map of messages in
// which every message is empty. Values of any other type report false.
func isEmptyMessageFieldValue(v reflect.Value) bool {
	if !v.IsValid() {
		return true // implies missing slice element or map entry
	}
	v = v.Elem() // map entries are always populated values
	t := v.Type()
	switch {
	case t == messageReflectType:
		// Singular message field.
		return isEmptyMessage(v)
	case t.Kind() == reflect.Slice && t.Elem() == messageReflectType:
		// List field: every element must be empty.
		for i, n := 0, v.Len(); i < n; i++ {
			if !isEmptyMessage(v.Index(i)) {
				return false
			}
		}
		return true
	case t.Kind() == reflect.Map && t.Elem() == messageReflectType:
		// Map field: every value must be empty.
		for _, key := range v.MapKeys() {
			if !isEmptyMessage(v.MapIndex(key)) {
				return false
			}
		}
		return true
	default:
		return false
	}
}
// isEmptyMessage reports whether v is absent (invalid) or a Message with no
// entries other than a non-nil messageTypeKey entry. Non-interfaceable values
// and values that are not a Message report false.
func isEmptyMessage(v reflect.Value) bool {
	switch {
	case !v.IsValid():
		return true // implies missing slice element or map entry
	case !v.CanInterface():
		return false // implies unexported struct field
	}

	m, isMessage := v.Interface().(Message)
	if !isMessage {
		return false
	}
	switch len(m) {
	case 0:
		return true
	case 1:
		// The lone entry may be the messageTypeKey entry, which holds type
		// information rather than a field value.
		return m[messageTypeKey] != nil
	default:
		return false
	}
}
// IgnoreUnknown ignores unknown fields in all messages.
// Unknown fields are recognized as Message map entries whose key consists
// solely of ASCII digits.
//
// This must be used in conjunction with Transform.
func IgnoreUnknown() cmp.Option {
	filter := func(p cmp.Path) bool {
		// Only map-index steps directly under a Message are candidates.
		idx, isMapIndex := p.Index(-1).(cmp.MapIndex)
		if !isMapIndex || p.Index(-2).Type() != messageReflectType {
			return false
		}
		// Unknown fields always have a numeric map key.
		notDigit := func(r rune) bool { return r < '0' || r > '9' }
		return strings.IndexFunc(idx.Key().String(), notDigit) < 0
	}
	return cmp.FilterPath(filter, cmp.Ignore())
}