math/big: specialize Karatsuba implementation for squaring

Currently we use three different algorithms for squaring:
1. basic multiplication for small numbers
2. basic squaring for medium numbers
3. Karatsuba multiplication for large numbers

Change 3. to a version of Karatsuba multiplication specialized
for x == y.

Increasing the performance of 3. lets us lower the threshold
between 2. and 3.

Adapt TestCalibrate to the change that 3. isn't independent
of the threshold between 1. and 2. any more.

Fixes #23221.

benchstat old.txt new.txt
name           old time/op  new time/op  delta
NatSqr/1-4     29.6ns ± 7%  29.5ns ± 5%     ~     (p=0.103 n=50+50)
NatSqr/2-4     51.9ns ± 1%  51.9ns ± 1%     ~     (p=0.693 n=42+49)
NatSqr/3-4     64.3ns ± 1%  64.1ns ± 0%   -0.26%  (p=0.000 n=46+43)
NatSqr/5-4     93.5ns ± 2%  93.1ns ± 1%   -0.39%  (p=0.000 n=48+49)
NatSqr/8-4      131ns ± 1%   131ns ± 1%     ~     (p=0.870 n=46+49)
NatSqr/10-4     175ns ± 1%   175ns ± 1%   +0.38%  (p=0.000 n=49+47)
NatSqr/20-4     426ns ± 1%   429ns ± 1%   +0.84%  (p=0.000 n=46+48)
NatSqr/30-4     702ns ± 2%   699ns ± 1%   -0.38%  (p=0.011 n=46+44)
NatSqr/50-4    1.44µs ± 2%  1.43µs ± 1%   -0.54%  (p=0.010 n=48+48)
NatSqr/80-4    2.85µs ± 1%  2.87µs ± 1%   +0.68%  (p=0.000 n=47+47)
NatSqr/100-4   4.06µs ± 1%  4.07µs ± 1%   +0.29%  (p=0.000 n=46+45)
NatSqr/200-4   13.4µs ± 1%  13.5µs ± 1%   +0.73%  (p=0.000 n=48+48)
NatSqr/300-4   28.5µs ± 1%  28.2µs ± 1%   -1.22%  (p=0.000 n=46+48)
NatSqr/500-4   81.9µs ± 1%  67.0µs ± 1%  -18.25%  (p=0.000 n=48+48)
NatSqr/800-4    161µs ± 1%   140µs ± 1%  -13.29%  (p=0.000 n=47+48)
NatSqr/1000-4   245µs ± 1%   207µs ± 1%  -15.17%  (p=0.000 n=49+49)

go test -v -calibrate --run TestCalibrate
...
Calibrating threshold between basicSqr(x) and karatsubaSqr(x)
Looking for a timing difference for x between 200 - 500 words by 10 step
words = 200 deltaT =     -980ns (  -7%) is karatsubaSqr(x) better: false
words = 210 deltaT =     -773ns (  -5%) is karatsubaSqr(x) better: false
words = 220 deltaT =     -695ns (  -4%) is karatsubaSqr(x) better: false
words = 230 deltaT =     -570ns (  -3%) is karatsubaSqr(x) better: false
words = 240 deltaT =     -458ns (  -2%) is karatsubaSqr(x) better: false
words = 250 deltaT =      -63ns (   0%) is karatsubaSqr(x) better: false
words = 260 deltaT =      118ns (   0%) is karatsubaSqr(x) better: true  threshold  found
words = 270 deltaT =      377ns (   1%) is karatsubaSqr(x) better: true
words = 280 deltaT =      765ns (   3%) is karatsubaSqr(x) better: true
words = 290 deltaT =      673ns (   2%) is karatsubaSqr(x) better: true
words = 300 deltaT =      502ns (   1%) is karatsubaSqr(x) better: true
words = 310 deltaT =      629ns (   2%) is karatsubaSqr(x) better: true
words = 320 deltaT =    1.011µs (   3%) is karatsubaSqr(x) better: true
words = 330 deltaT =     1.36µs (   4%) is karatsubaSqr(x) better: true
words = 340 deltaT =    3.001µs (   8%) is karatsubaSqr(x) better: true
words = 350 deltaT =    3.178µs (   8%) is karatsubaSqr(x) better: true
...

Change-Id: I6f13c23d94d042539ac28e77fd2618cdc37a429e
Reviewed-on: https://go-review.googlesource.com/105075
Run-TryBot: Robert Griesemer <gri@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Robert Griesemer <gri@golang.org>
diff --git a/src/math/big/calibrate_test.go b/src/math/big/calibrate_test.go
index 2b96e74..4fa663f 100644
--- a/src/math/big/calibrate_test.go
+++ b/src/math/big/calibrate_test.go
@@ -28,24 +28,32 @@
 
 var calibrate = flag.Bool("calibrate", false, "run calibration test")
 
-func TestCalibrate(t *testing.T) {
-	if *calibrate {
-		computeKaratsubaThresholds()
+const (
+	sqrModeMul       = "mul(x, x)"
+	sqrModeBasic     = "basicSqr(x)"
+	sqrModeKaratsuba = "karatsubaSqr(x)"
+)
 
-		// compute basicSqrThreshold where overhead becomes negligible
-		minSqr := computeSqrThreshold(10, 30, 1, 3)
-		// compute karatsubaSqrThreshold where karatsuba is faster
-		maxSqr := computeSqrThreshold(300, 500, 10, 3)
-		if minSqr != 0 {
-			fmt.Printf("found basicSqrThreshold = %d\n", minSqr)
-		} else {
-			fmt.Println("no basicSqrThreshold found")
-		}
-		if maxSqr != 0 {
-			fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr)
-		} else {
-			fmt.Println("no karatsubaSqrThreshold found")
-		}
+func TestCalibrate(t *testing.T) {
+	if !*calibrate {
+		return
+	}
+
+	computeKaratsubaThresholds()
+
+	// compute basicSqrThreshold where overhead becomes negligible
+	minSqr := computeSqrThreshold(10, 30, 1, 3, sqrModeMul, sqrModeBasic)
+	// compute karatsubaSqrThreshold where karatsuba is faster
+	maxSqr := computeSqrThreshold(200, 500, 10, 3, sqrModeBasic, sqrModeKaratsuba)
+	if minSqr != 0 {
+		fmt.Printf("found basicSqrThreshold = %d\n", minSqr)
+	} else {
+		fmt.Println("no basicSqrThreshold found")
+	}
+	if maxSqr != 0 {
+		fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr)
+	} else {
+		fmt.Println("no karatsubaSqrThreshold found")
 	}
 }
 
@@ -109,16 +117,17 @@
 	}
 }
 
-func measureBasicSqr(words, nruns int, enable bool) time.Duration {
+func measureSqr(words, nruns int, mode string) time.Duration {
 	// more runs for better statistics
 	initBasicSqr, initKaratsubaSqr := basicSqrThreshold, karatsubaSqrThreshold
 
-	if enable {
-		// set thresholds to use basicSqr at this number of words
+	switch mode {
+	case sqrModeMul:
+		basicSqrThreshold = words + 1
+	case sqrModeBasic:
 		basicSqrThreshold, karatsubaSqrThreshold = words-1, words+1
-	} else {
-		// set thresholds to disable basicSqr for any number of words
-		basicSqrThreshold, karatsubaSqrThreshold = -1, -1
+	case sqrModeKaratsuba:
+		karatsubaSqrThreshold = words - 1
 	}
 
 	var testval int64
@@ -133,18 +142,18 @@
 	return time.Duration(testval)
 }
 
-func computeSqrThreshold(from, to, step, nruns int) int {
-	fmt.Println("Calibrating thresholds for basicSqr via benchmarks of z.mul(x,x)")
+func computeSqrThreshold(from, to, step, nruns int, lower, upper string) int {
+	fmt.Printf("Calibrating threshold between %s and %s\n", lower, upper)
 	fmt.Printf("Looking for a timing difference for x between %d - %d words by %d step\n", from, to, step)
 	var initPos bool
 	var threshold int
 	for i := from; i <= to; i += step {
-		baseline := measureBasicSqr(i, nruns, false)
-		testval := measureBasicSqr(i, nruns, true)
+		baseline := measureSqr(i, nruns, lower)
+		testval := measureSqr(i, nruns, upper)
 		pos := baseline > testval
 		delta := baseline - testval
 		percent := delta * 100 / baseline
-		fmt.Printf("words = %3d deltaT = %10s (%4d%%) is basicSqr better: %v", i, delta, percent, pos)
+		fmt.Printf("words = %3d deltaT = %10s (%4d%%) is %s better: %v", i, delta, percent, upper, pos)
 		if i == from {
 			initPos = pos
 		}
diff --git a/src/math/big/nat.go b/src/math/big/nat.go
index 9ec8127..dc292b4 100644
--- a/src/math/big/nat.go
+++ b/src/math/big/nat.go
@@ -388,12 +388,12 @@
 }
 
 // karatsubaLen computes an approximation to the maximum k <= n such that
-// k = p<<i for a number p <= karatsubaThreshold and an i >= 0. Thus, the
+// k = p<<i for a number p <= threshold and an i >= 0. Thus, the
 // result is the largest number that can be divided repeatedly by 2 before
-// becoming about the value of karatsubaThreshold.
-func karatsubaLen(n int) int {
+// becoming about the value of threshold.
+func karatsubaLen(n, threshold int) int {
 	i := uint(0)
-	for n > karatsubaThreshold {
+	for n > threshold {
 		n >>= 1
 		i++
 	}
@@ -433,7 +433,7 @@
 	//   y = yh*b + y0  (0 <= y0 < b)
 	//   b = 1<<(_W*k)  ("base" of digits xi, yi)
 	//
-	k := karatsubaLen(n)
+	k := karatsubaLen(n, karatsubaThreshold)
 	// k <= n
 
 	// multiply x0 and y0 via Karatsuba
@@ -486,8 +486,8 @@
 
 // basicSqr sets z = x*x and is asymptotically faster than basicMul
 // by about a factor of 2, but slower for small arguments due to overhead.
-// Requirements: len(x) > 0, len(z) >= 2*len(x)
-// The (non-normalized) result is placed in z[0 : 2 * len(x)].
+// Requirements: len(x) > 0, len(z) == 2*len(x)
+// The (non-normalized) result is placed in z.
 func basicSqr(z, x nat) {
 	n := len(x)
 	t := make(nat, 2*n)            // temporary variable to hold the products
@@ -503,11 +503,48 @@
 	addVV(z, z, t)                              // combine the result
 }
 
+// karatsubaSqr squares x and leaves the result in z.
+// len(x) must be a power of 2 and len(z) >= 6*len(x).
+// The (non-normalized) result is placed in z[0 : 2*len(x)].
+//
+// The algorithm and the layout of z are the same as for karatsuba.
+func karatsubaSqr(z, x nat) {
+	n := len(x)
+
+	if n&1 != 0 || n < karatsubaSqrThreshold || n < 2 {
+		z = z[:2*n]
+		basicSqr(z, x)
+		return
+	}
+
+	n2 := n >> 1
+	x1, x0 := x[n2:], x[0:n2]
+
+	karatsubaSqr(z, x0)
+	karatsubaSqr(z[n:], x1)
+
+	// s = sign(xd*yd) == -1 for xd != 0; s == 1 for xd == 0
+	xd := z[2*n : 2*n+n2]
+	if subVV(xd, x1, x0) != 0 {
+		subVV(xd, x0, x1)
+	}
+
+	p := z[n*3:]
+	karatsubaSqr(p, xd)
+
+	r := z[n*4:]
+	copy(r, z[:n*2])
+
+	karatsubaAdd(z[n2:], r, n)
+	karatsubaAdd(z[n2:], r[n:], n)
+	karatsubaSub(z[n2:], p, n) // s == -1 for p != 0; s == 1 for p == 0
+}
+
 // Operands that are shorter than basicSqrThreshold are squared using
 // "grade school" multiplication; for operands longer than karatsubaSqrThreshold
-// the Karatsuba algorithm is used.
+// we use the Karatsuba algorithm optimized for x == y.
 var basicSqrThreshold = 20      // computed by calibrate_test.go
-var karatsubaSqrThreshold = 400 // computed by calibrate_test.go
+var karatsubaSqrThreshold = 260 // computed by calibrate_test.go
 
 // z = x*x
 func (z nat) sqr(x nat) nat {
@@ -536,7 +573,31 @@
 		return z.norm()
 	}
 
-	return z.mul(x, x)
+	// Use Karatsuba multiplication optimized for x == y.
+	// The algorithm and layout of z are the same as for mul.
+
+	// z = (x1*b + x0)^2 = x1^2*b^2 + 2*x1*x0*b + x0^2
+
+	k := karatsubaLen(n, karatsubaSqrThreshold)
+
+	x0 := x[0:k]
+	z = z.make(max(6*k, 2*n))
+	karatsubaSqr(z, x0) // z = x0^2
+	z = z[0 : 2*n]
+	z[2*k:].clear()
+
+	if k < n {
+		var t nat
+		x0 := x0.norm()
+		x1 := x[k:]
+		t = t.mul(x0, x1)
+		addAt(z, t, k)
+		addAt(z, t, k) // z = 2*x1*x0*b + x0^2
+		t = t.sqr(x1)
+		addAt(z, t, 2*k) // z = x1^2*b^2 + 2*x1*x0*b + x0^2
+	}
+
+	return z.norm()
 }
 
 // mulRange computes the product of all the unsigned integers in the
diff --git a/src/math/big/nat_test.go b/src/math/big/nat_test.go
index 0b94db3..3c79495 100644
--- a/src/math/big/nat_test.go
+++ b/src/math/big/nat_test.go
@@ -648,26 +648,26 @@
 	}
 }
 
-func testBasicSqr(t *testing.T, x nat) {
+func testSqr(t *testing.T, x nat) {
 	got := make(nat, 2*len(x))
 	want := make(nat, 2*len(x))
-	basicSqr(got, x)
-	basicMul(want, x, x)
+	got = got.sqr(x)
+	want = want.mul(x, x)
 	if got.cmp(want) != 0 {
 		t.Errorf("basicSqr(%v), got %v, want %v", x, got, want)
 	}
 }
 
-func TestBasicSqr(t *testing.T) {
+func TestSqr(t *testing.T) {
 	for _, a := range prodNN {
 		if a.x != nil {
-			testBasicSqr(t, a.x)
+			testSqr(t, a.x)
 		}
 		if a.y != nil {
-			testBasicSqr(t, a.y)
+			testSqr(t, a.y)
 		}
 		if a.z != nil {
-			testBasicSqr(t, a.z)
+			testSqr(t, a.z)
 		}
 	}
 }