| // Copyright 2024 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package impl |
| |
| import ( |
| "bytes" |
| |
| "google.golang.org/protobuf/encoding/protowire" |
| "google.golang.org/protobuf/reflect/protoreflect" |
| "google.golang.org/protobuf/runtime/protoiface" |
| ) |
| |
| func equal(in protoiface.EqualInput) protoiface.EqualOutput { |
| return protoiface.EqualOutput{Equal: equalMessage(in.MessageA, in.MessageB)} |
| } |
| |
| // equalMessage is a fast-path variant of protoreflect.equalMessage. |
| // It takes advantage of the internal messageState type to avoid |
| // unnecessary allocations, type assertions. |
| func equalMessage(mx, my protoreflect.Message) bool { |
| if mx == nil || my == nil { |
| return mx == my |
| } |
| if mx.Descriptor() != my.Descriptor() { |
| return false |
| } |
| |
| msx, ok := mx.(*messageState) |
| if !ok { |
| return protoreflect.ValueOfMessage(mx).Equal(protoreflect.ValueOfMessage(my)) |
| } |
| msy, ok := my.(*messageState) |
| if !ok { |
| return protoreflect.ValueOfMessage(mx).Equal(protoreflect.ValueOfMessage(my)) |
| } |
| |
| mi := msx.messageInfo() |
| miy := msy.messageInfo() |
| if mi != miy { |
| return protoreflect.ValueOfMessage(mx).Equal(protoreflect.ValueOfMessage(my)) |
| } |
| mi.init() |
| // Compares regular fields |
| // Modified Message.Range code that compares two messages of the same type |
| // while going over the fields. |
| for _, ri := range mi.rangeInfos { |
| var fd protoreflect.FieldDescriptor |
| var vx, vy protoreflect.Value |
| |
| switch ri := ri.(type) { |
| case *fieldInfo: |
| hx := ri.has(msx.pointer()) |
| hy := ri.has(msy.pointer()) |
| if hx != hy { |
| return false |
| } |
| if !hx { |
| continue |
| } |
| fd = ri.fieldDesc |
| vx = ri.get(msx.pointer()) |
| vy = ri.get(msy.pointer()) |
| case *oneofInfo: |
| fnx := ri.which(msx.pointer()) |
| fny := ri.which(msy.pointer()) |
| if fnx != fny { |
| return false |
| } |
| if fnx <= 0 { |
| continue |
| } |
| fi := mi.fields[fnx] |
| fd = fi.fieldDesc |
| vx = fi.get(msx.pointer()) |
| vy = fi.get(msy.pointer()) |
| } |
| |
| if !equalValue(fd, vx, vy) { |
| return false |
| } |
| } |
| |
| // Compare extensions. |
| // This is more complicated because mx or my could have empty/nil extension maps, |
| // however some populated extension map values are equal to nil extension maps. |
| emx := mi.extensionMap(msx.pointer()) |
| emy := mi.extensionMap(msy.pointer()) |
| if emx != nil { |
| for k, x := range *emx { |
| xd := x.Type().TypeDescriptor() |
| xv := x.Value() |
| var y ExtensionField |
| ok := false |
| if emy != nil { |
| y, ok = (*emy)[k] |
| } |
| // We need to treat empty lists as equal to nil values |
| if emy == nil || !ok { |
| if xd.IsList() && xv.List().Len() == 0 { |
| continue |
| } |
| return false |
| } |
| |
| if !equalValue(xd, xv, y.Value()) { |
| return false |
| } |
| } |
| } |
| if emy != nil { |
| // emy may have extensions emx does not have, need to check them as well |
| for k, y := range *emy { |
| if emx != nil { |
| // emx has the field, so we already checked it |
| if _, ok := (*emx)[k]; ok { |
| continue |
| } |
| } |
| // Empty lists are equal to nil |
| if y.Type().TypeDescriptor().IsList() && y.Value().List().Len() == 0 { |
| continue |
| } |
| |
| // Cant be equal if the extension is populated |
| return false |
| } |
| } |
| |
| return equalUnknown(mx.GetUnknown(), my.GetUnknown()) |
| } |
| |
| func equalValue(fd protoreflect.FieldDescriptor, vx, vy protoreflect.Value) bool { |
| // slow path |
| if fd.Kind() != protoreflect.MessageKind { |
| return vx.Equal(vy) |
| } |
| |
| // fast path special cases |
| if fd.IsMap() { |
| if fd.MapValue().Kind() == protoreflect.MessageKind { |
| return equalMessageMap(vx.Map(), vy.Map()) |
| } |
| return vx.Equal(vy) |
| } |
| |
| if fd.IsList() { |
| return equalMessageList(vx.List(), vy.List()) |
| } |
| |
| return equalMessage(vx.Message(), vy.Message()) |
| } |
| |
| // Mostly copied from protoreflect.equalMap. |
| // This variant only works for messages as map types. |
| // All other map types should be handled via Value.Equal. |
| func equalMessageMap(mx, my protoreflect.Map) bool { |
| if mx.Len() != my.Len() { |
| return false |
| } |
| equal := true |
| mx.Range(func(k protoreflect.MapKey, vx protoreflect.Value) bool { |
| if !my.Has(k) { |
| equal = false |
| return false |
| } |
| vy := my.Get(k) |
| equal = equalMessage(vx.Message(), vy.Message()) |
| return equal |
| }) |
| return equal |
| } |
| |
| // Mostly copied from protoreflect.equalList. |
| // The only change is the usage of equalImpl instead of protoreflect.equalValue. |
| func equalMessageList(lx, ly protoreflect.List) bool { |
| if lx.Len() != ly.Len() { |
| return false |
| } |
| for i := 0; i < lx.Len(); i++ { |
| // We only operate on messages here since equalImpl will not call us in any other case. |
| if !equalMessage(lx.Get(i).Message(), ly.Get(i).Message()) { |
| return false |
| } |
| } |
| return true |
| } |
| |
| // equalUnknown compares unknown fields by direct comparison on the raw bytes |
| // of each individual field number. |
| // Copied from protoreflect.equalUnknown. |
| func equalUnknown(x, y protoreflect.RawFields) bool { |
| if len(x) != len(y) { |
| return false |
| } |
| if bytes.Equal([]byte(x), []byte(y)) { |
| return true |
| } |
| |
| mx := make(map[protoreflect.FieldNumber]protoreflect.RawFields) |
| my := make(map[protoreflect.FieldNumber]protoreflect.RawFields) |
| for len(x) > 0 { |
| fnum, _, n := protowire.ConsumeField(x) |
| mx[fnum] = append(mx[fnum], x[:n]...) |
| x = x[n:] |
| } |
| for len(y) > 0 { |
| fnum, _, n := protowire.ConsumeField(y) |
| my[fnum] = append(my[fnum], y[:n]...) |
| y = y[n:] |
| } |
| if len(mx) != len(my) { |
| return false |
| } |
| |
| for k, v1 := range mx { |
| if v2, ok := my[k]; !ok || !bytes.Equal([]byte(v1), []byte(v2)) { |
| return false |
| } |
| } |
| |
| return true |
| } |