math/big: much simplified and faster Float rounding

Change-Id: Iab0add7aee51a8c72a81f51d980d22d2fd612f5c
Reviewed-on: https://go-review.googlesource.com/20817
Reviewed-by: Alan Donovan <adonovan@google.com>
diff --git a/src/math/big/float.go b/src/math/big/float.go
index f19f21f..4b8ad38 100644
--- a/src/math/big/float.go
+++ b/src/math/big/float.go
@@ -392,15 +392,13 @@
 	// m > 0 implies z.prec > 0 (checked by validate)
 
 	m := uint32(len(z.mant)) // present mantissa length in words
-	bits := m * _W           // present mantissa bits
+	bits := m * _W           // present mantissa bits; bits > 0
 	if bits <= z.prec {
 		// mantissa fits => nothing to do
 		return
 	}
 	// bits > z.prec
 
-	n := (z.prec + (_W - 1)) / _W // mantissa length in words for desired precision
-
 	// Rounding is based on two bits: the rounding bit (rbit) and the
 	// sticky bit (sbit). The rbit is the bit immediately before the
 	// z.prec leading mantissa bits (the "0.5"). The sbit is set if any
@@ -415,111 +413,77 @@
 
 	// bits > z.prec: mantissa too large => round
 	r := uint(bits - z.prec - 1) // rounding bit position; r >= 0
-	rbit := z.mant.bit(r)        // rounding bit
+	rbit := z.mant.bit(r) & 1    // rounding bit; be safe and ensure it's a single bit
 	if sbit == 0 {
+		// TODO(gri) if rbit != 0 we don't need to compute sbit for some rounding modes (optimization)
 		sbit = z.mant.sticky(r)
 	}
-	if debugFloat && sbit&^1 != 0 {
-		panic(fmt.Sprintf("invalid sbit %#x", sbit))
-	}
-
-	// convert ToXInf rounding modes
-	mode := z.mode
-	switch mode {
-	case ToNegativeInf:
-		mode = ToZero
-		if z.neg {
-			mode = AwayFromZero
-		}
-	case ToPositiveInf:
-		mode = AwayFromZero
-		if z.neg {
-			mode = ToZero
-		}
-	}
+	sbit &= 1 // be safe and ensure it's a single bit
 
 	// cut off extra words
+	n := (z.prec + (_W - 1)) / _W // mantissa length in words for desired precision
 	if m > n {
 		copy(z.mant, z.mant[m-n:]) // move n last words to front
 		z.mant = z.mant[:n]
 	}
 
-	// determine number of trailing zero bits t
-	t := n*_W - z.prec // 0 <= t < _W
-	lsb := Word(1) << t
+	// determine number of trailing zero bits (ntz) and compute lsb mask of mantissa's least-significant word
+	ntz := n*_W - z.prec // 0 <= ntz < _W
+	lsb := Word(1) << ntz
 
-	// make rounding decision
-	// TODO(gri) This can be simplified (see Bits.round in bits_test.go).
-	switch mode {
-	case ToZero:
-		// nothing to do
-	case ToNearestEven, ToNearestAway:
-		if rbit == 0 {
-			// rounding bits == 0b0x
-			mode = ToZero
-		} else if sbit == 1 {
-			// rounding bits == 0b11
-			mode = AwayFromZero
-		}
-	case AwayFromZero:
-		if rbit|sbit == 0 {
-			mode = ToZero
-		}
-	default:
-		// ToXInf modes have been converted to ToZero or AwayFromZero
-		panic("unreachable")
-	}
-
-	// round and determine accuracy
-	switch mode {
-	case ToZero:
-		if rbit|sbit != 0 {
-			z.acc = Below
+	// round if result is inexact
+	if rbit|sbit != 0 {
+		// Make rounding decision: The result mantissa is truncated ("rounded down")
+		// by default. Decide if we need to increment, or "round up", the (unsigned)
+		// mantissa.
+		inc := false
+		switch z.mode {
+		case ToNegativeInf:
+			inc = z.neg
+		case ToZero:
+			// nothing to do
+		case ToNearestEven:
+			inc = rbit != 0 && (sbit != 0 || z.mant[0]&lsb != 0)
+		case ToNearestAway:
+			inc = rbit != 0
+		case AwayFromZero:
+			inc = true
+		case ToPositiveInf:
+			inc = !z.neg
+		default:
+			panic("unreachable")
 		}
 
-	case ToNearestEven, ToNearestAway:
-		if debugFloat && rbit != 1 {
-			panic("internal error in rounding")
-		}
-		if mode == ToNearestEven && sbit == 0 && z.mant[0]&lsb == 0 {
-			z.acc = Below
-			break
-		}
-		// mode == ToNearestAway || sbit == 1 || z.mant[0]&lsb != 0
-		fallthrough
+		// A positive result (!z.neg) is Above the exact result if we increment,
+		// and it's Below if we truncate (Exact results require no rounding).
+		// For a negative result (z.neg) it is exactly the opposite.
+		z.acc = makeAcc(inc != z.neg)
 
-	case AwayFromZero:
-		// add 1 to mantissa
-		if addVW(z.mant, z.mant, lsb) != 0 {
-			// overflow => shift mantissa right by 1 and add msb
-			shrVU(z.mant, z.mant, 1)
-			z.mant[n-1] |= 1 << (_W - 1)
-			// adjust exponent
-			if z.exp < MaxExp {
+		if inc {
+			// add 1 to mantissa
+			if addVW(z.mant, z.mant, lsb) != 0 {
+				// mantissa overflow => adjust exponent
+				if z.exp >= MaxExp {
+					// exponent overflow
+					z.form = inf
+					return
+				}
 				z.exp++
-			} else {
-				// exponent overflow
-				z.acc = makeAcc(!z.neg)
-				z.form = inf
-				return
+				// adjust mantissa: divide by 2 to compensate for exponent adjustment
+				shrVU(z.mant, z.mant, 1)
+				// set msb == carry == 1 from the mantissa overflow above
+				const msb = 1 << (_W - 1)
+				z.mant[n-1] |= msb
 			}
 		}
-		z.acc = Above
 	}
 
 	// zero out trailing bits in least-significant word
 	z.mant[0] &^= lsb - 1
 
-	// update accuracy
-	if z.acc != Exact && z.neg {
-		z.acc = -z.acc
-	}
-
 	if debugFloat {
 		z.validate()
 	}
-
-	return
 }
 
 func (z *Float) setBits64(neg bool, x uint64) *Float {