big: use fast shift routines
- fixed a couple of bugs in the process
(shift right was incorrect for negative numbers)
- added more tests and made some tests more robust
- changed pidigits back to using shifts to multiply
by 2 instead of adding
This improves pidigits -s -n 10000 by approx. 5%:
user 0m6.496s (old)
user 0m6.156s (new)
R=rsc
CC=golang-dev
https://golang.org/cl/963044
diff --git a/src/pkg/big/nat.go b/src/pkg/big/nat.go
index 2db9e59..ff8e806 100644
--- a/src/pkg/big/nat.go
+++ b/src/pkg/big/nat.go
@@ -554,8 +554,8 @@
// D1.
shift := uint(leadingZeroBits(v[n-1]))
- v.shiftLeft(v, shift)
- u.shiftLeft(uIn, shift)
+ v.shiftLeftDeprecated(v, shift)
+ u.shiftLeftDeprecated(uIn, shift)
u[len(uIn)] = uIn[len(uIn)-1] >> (_W - uint(shift))
// D2.
@@ -597,8 +597,8 @@
}
q = q.norm()
- u.shiftRight(u, shift)
- v.shiftRight(v, shift)
+ u.shiftRightDeprecated(u, shift)
+ v.shiftRightDeprecated(v, shift)
r = u.norm()
return q, r
@@ -780,12 +780,56 @@
}
-// TODO(gri) Make the shift routines faster.
-// Use pidigits.go benchmark as a test case.
+// shl sets z = x << s and returns z.norm(); z is reused unless it aliases x.
+func (z nat) shl(x nat, s uint) nat {
+	m := len(x)
+	if m == 0 {
+		return z.make(0)
+	}
+	// TODO(gri) change shlVW so we don't need this
+	if len(z) > 0 && alias(z, x) {
+		z = nil // z is an alias for x - cannot reuse
+	}
+	// shift by s/_W whole words plus s%_W bits; z[n] takes shlVW's result word
+	n := m + int(s/_W)
+	z = z.make(n + 1)
+	z[n] = shlVW(&z[n-m], &x[0], Word(s%_W), m)
+	for i := 0; i < n-m; i++ {
+		z[i] = 0 // vacated low words must be 0; z.make may return reused storage (TODO confirm)
+	}
+	return z.norm()
+}
+
+
+// shr sets z = x >> s and returns z.norm(); z is reused unless it aliases x.
+func (z nat) shr(x nat, s uint) nat {
+	m := len(x)
+	n := m - int(s/_W) // words left once s/_W whole low words are dropped
+	if n <= 0 {
+		return z.make(0) // every word of x is shifted out
+	}
+	// determine if z can be reused
+	// TODO(gri) change shrVW so we don't need this
+	if len(z) > 0 && alias(z, x) {
+		z = nil // z is an alias for x - cannot reuse
+	}
+
+	// Both z and x[m-n:] hold exactly n words, so the count given to shrVW
+	// must be n; passing m would run past both slices whenever s >= _W.
+	z = z.make(n)
+	shrVW(&z[0], &x[m-n], Word(s%_W), n)
+
+	return z.norm()
+}
+
+
+// TODO(gri) Remove these shift functions once shlVW and shrVW can be
+// used directly in divLarge and powersOfTwoDecompose
+//
// To avoid losing the top n bits, z should be sized so that
// len(z) == len(x) + 1.
-func (z nat) shiftLeft(x nat, n uint) nat {
+func (z nat) shiftLeftDeprecated(x nat, n uint) nat {
if len(x) == 0 {
return x
}
@@ -805,7 +849,7 @@
}
-func (z nat) shiftRight(x nat, n uint) nat {
+func (z nat) shiftRightDeprecated(x nat, n uint) nat {
if len(x) == 0 {
return x
}
@@ -850,7 +894,7 @@
x := trailingZeroBits(n[zeroWords])
q = q.make(len(n) - zeroWords)
- q.shiftRight(n[zeroWords:], uint(x))
+ q.shiftRightDeprecated(n[zeroWords:], uint(x))
q = q.norm()
k = Word(_W*zeroWords + x)