math/big: much simplified and faster Float rounding
Change-Id: Iab0add7aee51a8c72a81f51d980d22d2fd612f5c
Reviewed-on: https://go-review.googlesource.com/20817
Reviewed-by: Alan Donovan <adonovan@google.com>
diff --git a/src/math/big/float.go b/src/math/big/float.go
index f19f21f..4b8ad38 100644
--- a/src/math/big/float.go
+++ b/src/math/big/float.go
@@ -392,15 +392,13 @@
// m > 0 implies z.prec > 0 (checked by validate)
m := uint32(len(z.mant)) // present mantissa length in words
- bits := m * _W // present mantissa bits
+ bits := m * _W // present mantissa bits; bits > 0
if bits <= z.prec {
// mantissa fits => nothing to do
return
}
// bits > z.prec
- n := (z.prec + (_W - 1)) / _W // mantissa length in words for desired precision
-
// Rounding is based on two bits: the rounding bit (rbit) and the
// sticky bit (sbit). The rbit is the bit immediately before the
// z.prec leading mantissa bits (the "0.5"). The sbit is set if any
@@ -415,111 +413,77 @@
// bits > z.prec: mantissa too large => round
r := uint(bits - z.prec - 1) // rounding bit position; r >= 0
- rbit := z.mant.bit(r) // rounding bit
+ rbit := z.mant.bit(r) & 1 // rounding bit; be safe and ensure it's a single bit
if sbit == 0 {
+ // TODO(gri) if rbit != 0 we don't need to compute sbit for some rounding modes (optimization)
sbit = z.mant.sticky(r)
}
- if debugFloat && sbit&^1 != 0 {
- panic(fmt.Sprintf("invalid sbit %#x", sbit))
- }
-
- // convert ToXInf rounding modes
- mode := z.mode
- switch mode {
- case ToNegativeInf:
- mode = ToZero
- if z.neg {
- mode = AwayFromZero
- }
- case ToPositiveInf:
- mode = AwayFromZero
- if z.neg {
- mode = ToZero
- }
- }
+ sbit &= 1 // be safe and ensure it's a single bit
// cut off extra words
+ n := (z.prec + (_W - 1)) / _W // mantissa length in words for desired precision
if m > n {
copy(z.mant, z.mant[m-n:]) // move n last words to front
z.mant = z.mant[:n]
}
- // determine number of trailing zero bits t
- t := n*_W - z.prec // 0 <= t < _W
- lsb := Word(1) << t
+ // determine number of trailing zero bits (ntz) and compute lsb mask of mantissa's least-significant word
+ ntz := n*_W - z.prec // 0 <= ntz < _W
+ lsb := Word(1) << ntz
- // make rounding decision
- // TODO(gri) This can be simplified (see Bits.round in bits_test.go).
- switch mode {
- case ToZero:
- // nothing to do
- case ToNearestEven, ToNearestAway:
- if rbit == 0 {
- // rounding bits == 0b0x
- mode = ToZero
- } else if sbit == 1 {
- // rounding bits == 0b11
- mode = AwayFromZero
- }
- case AwayFromZero:
- if rbit|sbit == 0 {
- mode = ToZero
- }
- default:
- // ToXInf modes have been converted to ToZero or AwayFromZero
- panic("unreachable")
- }
-
- // round and determine accuracy
- switch mode {
- case ToZero:
- if rbit|sbit != 0 {
- z.acc = Below
+ // round if result is inexact
+ if rbit|sbit != 0 {
+ // Make rounding decision: The result mantissa is truncated ("rounded down")
+ // by default. Decide if we need to increment, or "round up", the (unsigned)
+ // mantissa.
+ inc := false
+ switch z.mode {
+ case ToNegativeInf:
+ inc = z.neg
+ case ToZero:
+ // nothing to do
+ case ToNearestEven:
+ inc = rbit != 0 && (sbit != 0 || z.mant[0]&lsb != 0)
+ case ToNearestAway:
+ inc = rbit != 0
+ case AwayFromZero:
+ inc = true
+ case ToPositiveInf:
+ inc = !z.neg
+ default:
+ panic("unreachable")
}
- case ToNearestEven, ToNearestAway:
- if debugFloat && rbit != 1 {
- panic("internal error in rounding")
- }
- if mode == ToNearestEven && sbit == 0 && z.mant[0]&lsb == 0 {
- z.acc = Below
- break
- }
- // mode == ToNearestAway || sbit == 1 || z.mant[0]&lsb != 0
- fallthrough
+ // A positive result (!z.neg) is Above the exact result if we increment,
+ // and it's Below if we truncate (Exact results require no rounding).
+ // For a negative result (z.neg) it is exactly the opposite.
+ z.acc = makeAcc(inc != z.neg)
- case AwayFromZero:
- // add 1 to mantissa
- if addVW(z.mant, z.mant, lsb) != 0 {
- // overflow => shift mantissa right by 1 and add msb
- shrVU(z.mant, z.mant, 1)
- z.mant[n-1] |= 1 << (_W - 1)
- // adjust exponent
- if z.exp < MaxExp {
+ if inc {
+ // add 1 to mantissa
+ if addVW(z.mant, z.mant, lsb) != 0 {
+ // mantissa overflow => adjust exponent
+ if z.exp >= MaxExp {
+ // exponent overflow
+ z.form = inf
+ return
+ }
z.exp++
- } else {
- // exponent overflow
- z.acc = makeAcc(!z.neg)
- z.form = inf
- return
+ // adjust mantissa: divide by 2 to compensate for exponent adjustment
+ shrVU(z.mant, z.mant, 1)
+ // set msb == carry == 1 from the mantissa overflow above
+ const msb = 1 << (_W - 1)
+ z.mant[n-1] |= msb
}
}
- z.acc = Above
}
// zero out trailing bits in least-significant word
z.mant[0] &^= lsb - 1
- // update accuracy
- if z.acc != Exact && z.neg {
- z.acc = -z.acc
- }
-
if debugFloat {
z.validate()
}
-
- return
}
func (z *Float) setBits64(neg bool, x uint64) *Float {