reflect/protoreflect: add Value.Equal method

The Value.Equal method compares Values in a way that is
analogous to the reflect.DeepEqual function.

Most of the implementation is moved from the "proto" package,
with proto.Equal implemented in terms of protoreflect.Value.Equal.

Change-Id: Ie9d5f6c073dc49fa59f78385c441db8137de5ec5
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/450775
TryBot-Bypass: Heschi Kreinick <heschi@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Oleksii Prokopchuk <prokopchuk@google.com>
Reviewed-by: Christian Höppner <hoeppi@google.com>
Run-TryBot: Lasse Folger <lassefolger@google.com>
diff --git a/proto/equal.go b/proto/equal.go
index 67948dd..baa1bc1 100644
--- a/proto/equal.go
+++ b/proto/equal.go
@@ -5,30 +5,34 @@
 package proto
 
 import (
-	"bytes"
-	"math"
 	"reflect"
 
-	"google.golang.org/protobuf/encoding/protowire"
 	"google.golang.org/protobuf/reflect/protoreflect"
 )
 
-// Equal reports whether two messages are equal.
+// Equal reports whether two messages are equal,
+// by recursively comparing the fields of the message.
+//
+//   - Bytes fields are equal if they contain identical bytes.
+//     Empty bytes (regardless of nil-ness) are considered equal.
+//
+//   - Floating-point fields are equal if they contain the same value.
+//     Unlike the == operator, a NaN is equal to another NaN.
+//
+//   - Other scalar fields are equal if they contain the same value.
+//
+//   - Message fields are equal if they have
+//     the same set of populated known and extension field values, and
+//     the same set of unknown fields values.
+//
+//   - Lists are equal if they are the same length and
+//     each corresponding element is equal.
+//
+//   - Maps are equal if they have the same set of keys and
+//     the corresponding value for each key is equal.
+//
 // If two messages marshal to the same bytes under deterministic serialization,
 // then Equal is guaranteed to report true.
-//
-// Two messages are equal if they belong to the same message descriptor,
-// have the same set of populated known and extension field values,
-// and the same set of unknown fields values. If either of the top-level
-// messages are invalid, then Equal reports true only if both are invalid.
-//
-// Scalar values are compared with the equivalent of the == operator in Go,
-// except bytes values which are compared using bytes.Equal and
-// floating point values which specially treat NaNs as equal.
-// Message values are compared by recursively calling Equal.
-// Lists are equal if each element value is also equal.
-// Maps are equal if they have the same set of keys, where the pair of values
-// for each key is also equal.
 func Equal(x, y Message) bool {
 	if x == nil || y == nil {
 		return x == nil && y == nil
@@ -42,130 +46,7 @@
 	if mx.IsValid() != my.IsValid() {
 		return false
 	}
-	return equalMessage(mx, my)
-}
-
-// equalMessage compares two messages.
-func equalMessage(mx, my protoreflect.Message) bool {
-	if mx.Descriptor() != my.Descriptor() {
-		return false
-	}
-
-	nx := 0
-	equal := true
-	mx.Range(func(fd protoreflect.FieldDescriptor, vx protoreflect.Value) bool {
-		nx++
-		vy := my.Get(fd)
-		equal = my.Has(fd) && equalField(fd, vx, vy)
-		return equal
-	})
-	if !equal {
-		return false
-	}
-	ny := 0
-	my.Range(func(fd protoreflect.FieldDescriptor, vx protoreflect.Value) bool {
-		ny++
-		return true
-	})
-	if nx != ny {
-		return false
-	}
-
-	return equalUnknown(mx.GetUnknown(), my.GetUnknown())
-}
-
-// equalField compares two fields.
-func equalField(fd protoreflect.FieldDescriptor, x, y protoreflect.Value) bool {
-	switch {
-	case fd.IsList():
-		return equalList(fd, x.List(), y.List())
-	case fd.IsMap():
-		return equalMap(fd, x.Map(), y.Map())
-	default:
-		return equalValue(fd, x, y)
-	}
-}
-
-// equalMap compares two maps.
-func equalMap(fd protoreflect.FieldDescriptor, x, y protoreflect.Map) bool {
-	if x.Len() != y.Len() {
-		return false
-	}
-	equal := true
-	x.Range(func(k protoreflect.MapKey, vx protoreflect.Value) bool {
-		vy := y.Get(k)
-		equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy)
-		return equal
-	})
-	return equal
-}
-
-// equalList compares two lists.
-func equalList(fd protoreflect.FieldDescriptor, x, y protoreflect.List) bool {
-	if x.Len() != y.Len() {
-		return false
-	}
-	for i := x.Len() - 1; i >= 0; i-- {
-		if !equalValue(fd, x.Get(i), y.Get(i)) {
-			return false
-		}
-	}
-	return true
-}
-
-// equalValue compares two singular values.
-func equalValue(fd protoreflect.FieldDescriptor, x, y protoreflect.Value) bool {
-	switch fd.Kind() {
-	case protoreflect.BoolKind:
-		return x.Bool() == y.Bool()
-	case protoreflect.EnumKind:
-		return x.Enum() == y.Enum()
-	case protoreflect.Int32Kind, protoreflect.Sint32Kind,
-		protoreflect.Int64Kind, protoreflect.Sint64Kind,
-		protoreflect.Sfixed32Kind, protoreflect.Sfixed64Kind:
-		return x.Int() == y.Int()
-	case protoreflect.Uint32Kind, protoreflect.Uint64Kind,
-		protoreflect.Fixed32Kind, protoreflect.Fixed64Kind:
-		return x.Uint() == y.Uint()
-	case protoreflect.FloatKind, protoreflect.DoubleKind:
-		fx := x.Float()
-		fy := y.Float()
-		if math.IsNaN(fx) || math.IsNaN(fy) {
-			return math.IsNaN(fx) && math.IsNaN(fy)
-		}
-		return fx == fy
-	case protoreflect.StringKind:
-		return x.String() == y.String()
-	case protoreflect.BytesKind:
-		return bytes.Equal(x.Bytes(), y.Bytes())
-	case protoreflect.MessageKind, protoreflect.GroupKind:
-		return equalMessage(x.Message(), y.Message())
-	default:
-		return x.Interface() == y.Interface()
-	}
-}
-
-// equalUnknown compares unknown fields by direct comparison on the raw bytes
-// of each individual field number.
-func equalUnknown(x, y protoreflect.RawFields) bool {
-	if len(x) != len(y) {
-		return false
-	}
-	if bytes.Equal([]byte(x), []byte(y)) {
-		return true
-	}
-
-	mx := make(map[protoreflect.FieldNumber]protoreflect.RawFields)
-	my := make(map[protoreflect.FieldNumber]protoreflect.RawFields)
-	for len(x) > 0 {
-		fnum, _, n := protowire.ConsumeField(x)
-		mx[fnum] = append(mx[fnum], x[:n]...)
-		x = x[n:]
-	}
-	for len(y) > 0 {
-		fnum, _, n := protowire.ConsumeField(y)
-		my[fnum] = append(my[fnum], y[:n]...)
-		y = y[n:]
-	}
-	return reflect.DeepEqual(mx, my)
+	vx := protoreflect.ValueOfMessage(mx)
+	vy := protoreflect.ValueOfMessage(my)
+	return vx.Equal(vy)
 }
diff --git a/reflect/protoreflect/value_equal.go b/reflect/protoreflect/value_equal.go
new file mode 100644
index 0000000..5916525
--- /dev/null
+++ b/reflect/protoreflect/value_equal.go
@@ -0,0 +1,168 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package protoreflect
+
+import (
+	"bytes"
+	"fmt"
+	"math"
+	"reflect"
+
+	"google.golang.org/protobuf/encoding/protowire"
+)
+
+// Equal reports whether v1 and v2 are recursively equal.
+//
+//   - Values of different types are always unequal.
+//
+//   - Bytes values are equal if they contain identical bytes.
+//     Empty bytes (regardless of nil-ness) are considered equal.
+//
+//   - Floating point values are equal if they contain the same value.
+//     Unlike the == operator, a NaN is equal to another NaN.
+//
+//   - Enums are equal if they contain the same number.
+//     Since Value does not contain an enum descriptor,
+//     enum values do not consider the type of the enum.
+//
+//   - Other scalar values are equal if they contain the same value.
+//
+//   - Message values are equal if they belong to the same message descriptor,
+//     have the same set of populated known and extension field values,
+//     and the same set of unknown fields values.
+//
+//   - Lists are equal if they are the same length and
+//     each corresponding element is equal.
+//
+//   - Maps are equal if they have the same set of keys and
+//     the corresponding value for each key is equal.
+func (v1 Value) Equal(v2 Value) bool {
+	return equalValue(v1, v2)
+}
+
+func equalValue(x, y Value) bool {
+	eqType := x.typ == y.typ
+	switch x.typ {
+	case nilType:
+		return eqType
+	case boolType:
+		return eqType && x.Bool() == y.Bool()
+	case int32Type, int64Type:
+		return eqType && x.Int() == y.Int()
+	case uint32Type, uint64Type:
+		return eqType && x.Uint() == y.Uint()
+	case float32Type, float64Type:
+		return eqType && equalFloat(x.Float(), y.Float())
+	case stringType:
+		return eqType && x.String() == y.String()
+	case bytesType:
+		return eqType && bytes.Equal(x.Bytes(), y.Bytes())
+	case enumType:
+		return eqType && x.Enum() == y.Enum()
+	default:
+		switch x := x.Interface().(type) {
+		case Message:
+			y, ok := y.Interface().(Message)
+			return ok && equalMessage(x, y)
+		case List:
+			y, ok := y.Interface().(List)
+			return ok && equalList(x, y)
+		case Map:
+			y, ok := y.Interface().(Map)
+			return ok && equalMap(x, y)
+		default:
+			panic(fmt.Sprintf("unknown type: %T", x))
+		}
+	}
+}
+
+// equalFloat compares two floats, where NaNs are treated as equal.
+func equalFloat(x, y float64) bool {
+	if math.IsNaN(x) || math.IsNaN(y) {
+		return math.IsNaN(x) && math.IsNaN(y)
+	}
+	return x == y
+}
+
+// equalMessage compares two messages.
+func equalMessage(mx, my Message) bool {
+	if mx.Descriptor() != my.Descriptor() {
+		return false
+	}
+
+	nx := 0
+	equal := true
+	mx.Range(func(fd FieldDescriptor, vx Value) bool {
+		nx++
+		vy := my.Get(fd)
+		equal = my.Has(fd) && equalValue(vx, vy)
+		return equal
+	})
+	if !equal {
+		return false
+	}
+	ny := 0
+	my.Range(func(fd FieldDescriptor, vx Value) bool {
+		ny++
+		return true
+	})
+	if nx != ny {
+		return false
+	}
+
+	return equalUnknown(mx.GetUnknown(), my.GetUnknown())
+}
+
+// equalList compares two lists.
+func equalList(x, y List) bool {
+	if x.Len() != y.Len() {
+		return false
+	}
+	for i := x.Len() - 1; i >= 0; i-- {
+		if !equalValue(x.Get(i), y.Get(i)) {
+			return false
+		}
+	}
+	return true
+}
+
+// equalMap compares two maps.
+func equalMap(x, y Map) bool {
+	if x.Len() != y.Len() {
+		return false
+	}
+	equal := true
+	x.Range(func(k MapKey, vx Value) bool {
+		vy := y.Get(k)
+		equal = y.Has(k) && equalValue(vx, vy)
+		return equal
+	})
+	return equal
+}
+
+// equalUnknown compares unknown fields by direct comparison on the raw bytes
+// of each individual field number.
+func equalUnknown(x, y RawFields) bool {
+	if len(x) != len(y) {
+		return false
+	}
+	if bytes.Equal([]byte(x), []byte(y)) {
+		return true
+	}
+
+	mx := make(map[FieldNumber]RawFields)
+	my := make(map[FieldNumber]RawFields)
+	for len(x) > 0 {
+		fnum, _, n := protowire.ConsumeField(x)
+		mx[fnum] = append(mx[fnum], x[:n]...)
+		x = x[n:]
+	}
+	for len(y) > 0 {
+		fnum, _, n := protowire.ConsumeField(y)
+		my[fnum] = append(my[fnum], y[:n]...)
+		y = y[n:]
+	}
+	return reflect.DeepEqual(mx, my)
+}
diff --git a/reflect/protoreflect/value_test.go b/reflect/protoreflect/value_test.go
index c6bbb58..d387900 100644
--- a/reflect/protoreflect/value_test.go
+++ b/reflect/protoreflect/value_test.go
@@ -11,10 +11,13 @@
 	"testing"
 )
 
+var (
+	fakeMessage = new(struct{ Message })
+	fakeList    = new(struct{ List })
+	fakeMap     = new(struct{ Map })
+)
+
 func TestValue(t *testing.T) {
-	fakeMessage := new(struct{ Message })
-	fakeList := new(struct{ List })
-	fakeMap := new(struct{ Map })
 
 	tests := []struct {
 		in   Value
@@ -98,6 +101,47 @@
 	}
 }
 
+func TestValueEqual(t *testing.T) {
+	tests := []struct {
+		x, y Value
+		want bool
+	}{
+		{Value{}, Value{}, true},
+		{Value{}, ValueOfBool(true), false},
+		{ValueOfBool(true), ValueOfBool(true), true},
+		{ValueOfBool(true), ValueOfBool(false), false},
+		{ValueOfBool(false), ValueOfInt32(0), false},
+		{ValueOfInt32(0), ValueOfInt32(0), true},
+		{ValueOfInt32(0), ValueOfInt32(1), false},
+		{ValueOfInt32(0), ValueOfInt64(0), false},
+		{ValueOfInt64(123), ValueOfInt64(123), true},
+		{ValueOfFloat64(0), ValueOfFloat64(0), true},
+		{ValueOfFloat64(math.NaN()), ValueOfFloat64(math.NaN()), true},
+		{ValueOfFloat64(math.NaN()), ValueOfFloat64(0), false},
+		{ValueOfFloat64(math.Inf(1)), ValueOfFloat64(math.Inf(1)), true},
+		{ValueOfFloat64(math.Inf(-1)), ValueOfFloat64(math.Inf(1)), false},
+		{ValueOfBytes(nil), ValueOfBytes(nil), true},
+		{ValueOfBytes(nil), ValueOfBytes([]byte{}), true},
+		{ValueOfBytes(nil), ValueOfBytes([]byte{1}), false},
+		{ValueOfEnum(0), ValueOfEnum(0), true},
+		{ValueOfEnum(0), ValueOfEnum(1), false},
+		{ValueOfBool(false), ValueOfMessage(fakeMessage), false},
+		{ValueOfMessage(fakeMessage), ValueOfList(fakeList), false},
+		{ValueOfList(fakeList), ValueOfMap(fakeMap), false},
+		{ValueOfMap(fakeMap), ValueOfMessage(fakeMessage), false},
+
+		// Composite types are not tested here.
+		// See proto.TestEqual.
+	}
+
+	for _, tt := range tests {
+		got := tt.x.Equal(tt.y)
+		if got != tt.want {
+			t.Errorf("(%v).Equal(%v) = %v, want %v", tt.x, tt.y, got, tt.want)
+		}
+	}
+}
+
 func BenchmarkValue(b *testing.B) {
 	const testdata = "The quick brown fox jumped over the lazy dog."
 	var sink1 string