math/big: wrap Float.Cmp result in struct to prevent wrong use

Float.Cmp used to return a value < 0, 0, or > 0 depending on how
arguments x, y compared against each other. With the possibility
of NaNs, the result was changed into an Accuracy (to include Undef).
Consequently, Float.Cmp results could still be compared for (in-)
equality with 0, but comparing if < 0 or > 0 would provide the
wrong answer w/o any obvious notice by the compiler.

This change wraps Float.Cmp results into a struct and accessors
are used to access the desired result. This prevents incorrect
use.

Change-Id: I34e6a6c1859251ec99b5cf953e82542025ace56f
Reviewed-on: https://go-review.googlesource.com/7526
Reviewed-by: Rob Pike <r@golang.org>
diff --git a/src/math/big/float.go b/src/math/big/float.go
index 44691c4..d716c8c 100644
--- a/src/math/big/float.go
+++ b/src/math/big/float.go
@@ -1446,6 +1446,10 @@
 	return z
 }
 
+type cmpResult struct {
+	acc Accuracy
+}
+
 // Cmp compares x and y and returns:
 //
 //   Below if x <  y
@@ -1453,44 +1457,45 @@
 //   Above if x >  y
 //   Undef if any of x, y is NaN
 //
-func (x *Float) Cmp(y *Float) Accuracy {
+func (x *Float) Cmp(y *Float) cmpResult {
 	if debugFloat {
 		x.validate()
 		y.validate()
 	}
 
 	if x.form == nan || y.form == nan {
-		return Undef
+		return cmpResult{Undef}
 	}
 
 	mx := x.ord()
 	my := y.ord()
 	switch {
 	case mx < my:
-		return Below
+		return cmpResult{Below}
 	case mx > my:
-		return Above
+		return cmpResult{Above}
 	}
 	// mx == my
 
 	// only if |mx| == 1 we have to compare the mantissae
 	switch mx {
 	case -1:
-		return y.ucmp(x)
+		return cmpResult{y.ucmp(x)}
 	case +1:
-		return x.ucmp(y)
+		return cmpResult{x.ucmp(y)}
 	}
 
-	return Exact
+	return cmpResult{Exact}
 }
 
 // The following accessors simplify testing of Cmp results.
-func (acc Accuracy) Eql() bool { return acc == Exact }
-func (acc Accuracy) Neq() bool { return acc != Exact }
-func (acc Accuracy) Lss() bool { return acc == Below }
-func (acc Accuracy) Leq() bool { return acc&Above == 0 }
-func (acc Accuracy) Gtr() bool { return acc == Above }
-func (acc Accuracy) Geq() bool { return acc&Below == 0 }
+func (res cmpResult) Acc() Accuracy { return res.acc }
+func (res cmpResult) Eql() bool     { return res.acc == Exact }
+func (res cmpResult) Neq() bool     { return res.acc != Exact }
+func (res cmpResult) Lss() bool     { return res.acc == Below }
+func (res cmpResult) Leq() bool     { return res.acc&Above == 0 }
+func (res cmpResult) Gtr() bool     { return res.acc == Above }
+func (res cmpResult) Geq() bool     { return res.acc&Below == 0 }
 
 // ord classifies x and returns:
 //
diff --git a/src/math/big/float_test.go b/src/math/big/float_test.go
index 86b1c6f..683809b 100644
--- a/src/math/big/float_test.go
+++ b/src/math/big/float_test.go
@@ -207,7 +207,7 @@
 	if x.IsNaN() || y.IsNaN() {
 		return x.IsNaN() && y.IsNaN()
 	}
-	return x.Cmp(y) == 0 && x.IsNeg() == y.IsNeg()
+	return x.Cmp(y).Eql() && x.IsNeg() == y.IsNeg()
 }
 
 func TestFloatMantExp(t *testing.T) {
@@ -918,7 +918,7 @@
 		// inverse conversion
 		if res != nil {
 			got := new(Float).SetPrec(64).SetRat(res)
-			if got.Cmp(x) != 0 {
+			if got.Cmp(x).Neq() {
 				t.Errorf("%s: got %s; want %s", test.x, got, x)
 			}
 		}
@@ -995,7 +995,7 @@
 		for i := 0; i < n; i++ {
 			x.Add(&x, &one)
 		}
-		if x.Cmp(new(Float).SetInt64(n)) != 0 {
+		if x.Cmp(new(Float).SetInt64(n)).Neq() {
 			t.Errorf("prec = %d: got %s; want %d", prec, &x, n)
 		}
 	}
@@ -1036,14 +1036,14 @@
 					got := new(Float).SetPrec(prec).SetMode(mode)
 					got.Add(x, y)
 					want := zbits.round(prec, mode)
-					if got.Cmp(want) != 0 {
+					if got.Cmp(want).Neq() {
 						t.Errorf("i = %d, prec = %d, %s:\n\t     %s %v\n\t+    %s %v\n\t=    %s\n\twant %s",
 							i, prec, mode, x, xbits, y, ybits, got, want)
 					}
 
 					got.Sub(z, x)
 					want = ybits.round(prec, mode)
-					if got.Cmp(want) != 0 {
+					if got.Cmp(want).Neq() {
 						t.Errorf("i = %d, prec = %d, %s:\n\t     %s %v\n\t-    %s %v\n\t=    %s\n\twant %s",
 							i, prec, mode, z, zbits, x, xbits, got, want)
 					}
@@ -1137,7 +1137,7 @@
 					got := new(Float).SetPrec(prec).SetMode(mode)
 					got.Mul(x, y)
 					want := zbits.round(prec, mode)
-					if got.Cmp(want) != 0 {
+					if got.Cmp(want).Neq() {
 						t.Errorf("i = %d, prec = %d, %s:\n\t     %s %v\n\t*    %s %v\n\t=    %s\n\twant %s",
 							i, prec, mode, x, xbits, y, ybits, got, want)
 					}
@@ -1147,7 +1147,7 @@
 					}
 					got.Quo(z, x)
 					want = ybits.round(prec, mode)
-					if got.Cmp(want) != 0 {
+					if got.Cmp(want).Neq() {
 						t.Errorf("i = %d, prec = %d, %s:\n\t     %s %v\n\t/    %s %v\n\t=    %s\n\twant %s",
 							i, prec, mode, z, zbits, x, xbits, got, want)
 					}
@@ -1230,7 +1230,7 @@
 		p.Mul(p, psix)
 		z2.Sub(two, p)
 
-		if z1.Cmp(z2) != 0 {
+		if z1.Cmp(z2).Neq() {
 			t.Fatalf("prec %d: got z1 = %s != z2 = %s; want z1 == z2\n", prec, z1, z2)
 		}
 		if z1.Sign() != 0 {
@@ -1281,7 +1281,7 @@
 				prec := uint(preci + d)
 				got := new(Float).SetPrec(prec).SetMode(mode).Quo(x, y)
 				want := bits.round(prec, mode)
-				if got.Cmp(want) != 0 {
+				if got.Cmp(want).Neq() {
 					t.Errorf("i = %d, prec = %d, %s:\n\t     %s\n\t/    %s\n\t=    %s\n\twant %s",
 						i, prec, mode, x, y, got, want)
 				}
@@ -1462,7 +1462,7 @@
 			}
 			for _, y := range args {
 				yy.SetFloat64(y)
-				got := xx.Cmp(yy)
+				got := xx.Cmp(yy).Acc()
 				want := Undef
 				switch {
 				case x < y:
diff --git a/src/math/big/floatconv_test.go b/src/math/big/floatconv_test.go
index e7920d0..17c8b14 100644
--- a/src/math/big/floatconv_test.go
+++ b/src/math/big/floatconv_test.go
@@ -102,7 +102,7 @@
 		}
 		f, _ := x.Float64()
 		want := new(Float).SetFloat64(test.x)
-		if x.Cmp(want) != 0 {
+		if x.Cmp(want).Neq() {
 			t.Errorf("%s: got %s (%v); want %v", test.s, &x, f, test.x)
 		}
 	}
diff --git a/src/math/big/floatexample_test.go b/src/math/big/floatexample_test.go
index 181c0bc..5e4a9cb 100644
--- a/src/math/big/floatexample_test.go
+++ b/src/math/big/floatexample_test.go
@@ -63,7 +63,7 @@
 			t := x.Cmp(y)
 			fmt.Printf(
 				"%4s  %4s  %5s   %c    %c    %c    %c    %c    %c\n",
-				x, y, t,
+				x, y, t.Acc(),
 				mark(t.Eql()), mark(t.Neq()), mark(t.Lss()), mark(t.Leq()), mark(t.Gtr()), mark(t.Geq()))
 		}
 		fmt.Println()