crypto/elliptic: fix P-224 field reduction

This patch fixes two independent bugs in p224Contract, the function that
performs the final complete reduction in the P-224 field. Incorrect
outputs due to these bugs were observable from a high-level
P224().ScalarMult() call.

The first bug was in the calculation of out3GT. That mask was supposed
to be all ones if the third limb of the value is greater than the third
limb of P (out[3] > 0xffff000). Instead, it was also set if they are
equal. That meant that if the third limb was equal, the value was always
considered greater than or equal to P, even when the three bottom limbs
were all zero. There is exactly one affected value, P - 1, which would
trigger the subtraction by P even if it's lower than P already.

The second bug was more easily hit, and is the one that caused the known
high-level incorrect output: after the conditional subtraction by P, a
potential underflow of the lowest limb was not handled. Any values that
trigger the subtraction by P (values between P and 2^224-1, and P - 1
due to the bug above) but have a zero lowest limb would produce invalid
outputs. Those conditions apply to the intermediate representation
before the subtraction, so they are hard to trace to precise inputs.

This patch also adds a test suite for the P-224 field arithmetic,
including a custom fuzzer that automatically explores potential edge
cases by combining limb values that have various meanings in the code.
contractMatchesBigInt in TestP224Contract finds the second bug in less
than a second without being tailored to it, and could eventually find
the first one too by combining 0, (1 << 28) - 1, and the difference of
(1 << 28) and (1 << 12).

The incorrect P224().ScalarMult() output was found by the
elliptic-curve-differential-fuzzer project running on OSS-Fuzz and
reported by Philippe Antoine (Catena cyber).

Fixes CVE-2021-3114
Fixes #43786

Change-Id: I50176602d544de3da854270d66a293bcaca57ad7
Reviewed-on: https://go-review.googlesource.com/c/go/+/284779
Run-TryBot: Roland Shoemaker <roland@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Ian Lance Taylor <iant@golang.org>
Trust: Roland Shoemaker <roland@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
diff --git a/src/crypto/elliptic/p224.go b/src/crypto/elliptic/p224.go
index 2ea63f3..8c76021 100644
--- a/src/crypto/elliptic/p224.go
+++ b/src/crypto/elliptic/p224.go
@@ -386,10 +386,11 @@
 // p224Contract converts a FieldElement to its unique, minimal form.
 //
 // On entry, in[i] < 2**29
-// On exit, in[i] < 2**28
+// On exit, out[i] < 2**28 and out < p
 func p224Contract(out, in *p224FieldElement) {
 	copy(out[:], in[:])
 
+	// First, carry the bits above 28 to the higher limb.
 	for i := 0; i < 7; i++ {
 		out[i+1] += out[i] >> 28
 		out[i] &= bottom28Bits
@@ -397,10 +398,13 @@
 	top := out[7] >> 28
 	out[7] &= bottom28Bits
 
+	// Use the reduction identity to carry the overflow.
+	//
+	//   a + top * 2²²⁴ = a + top * 2⁹⁶ - top
 	out[0] -= top
 	out[3] += top << 12
 
-	// We may just have made out[i] negative. So we carry down. If we made
+	// We may just have made out[0] negative. So we carry down. If we made
 	// out[0] negative then we know that out[3] is sufficiently positive
 	// because we just added to it.
 	for i := 0; i < 3; i++ {
@@ -425,13 +429,12 @@
 	// There are two cases to consider for out[3]:
 	//   1) The first time that we eliminated top, we didn't push out[3] over
 	//      2**28. In this case, the partial carry chain didn't change any values
-	//      and top is zero.
+	//      and top is now zero.
 	//   2) We did push out[3] over 2**28 the first time that we eliminated top.
-	//      The first value of top was in [0..16), therefore, prior to eliminating
-	//      the first top, 0xfff1000 <= out[3] <= 0xfffffff. Therefore, after
-	//      overflowing and being reduced by the second carry chain, out[3] <=
-	//      0xf000. Thus it cannot have overflowed when we eliminated top for the
-	//      second time.
+	//      The first value of top was in [0..2], therefore, after overflowing
+	//      and being reduced by the second carry chain, out[3] <= 2<<12 - 1.
+	// In both cases, out[3] cannot have overflowed when we eliminated top for
+	// the second time.
 
 	// Again, we may just have made out[0] negative, so do the same carry down.
 	// As before, if we made out[0] negative then we know that out[3] is
@@ -470,12 +473,11 @@
 	bottom3NonZero |= bottom3NonZero >> 1
 	bottom3NonZero = uint32(int32(bottom3NonZero<<31) >> 31)
 
-	// Everything depends on the value of out[3].
-	//    If it's > 0xffff000 and top4AllOnes != 0 then the whole value is >= p
-	//    If it's = 0xffff000 and top4AllOnes != 0 and bottom3NonZero != 0,
-	//      then the whole value is >= p
+	// Assuming top4AllOnes != 0, everything depends on the value of out[3].
+	//    If it's > 0xffff000 then the whole value is > p
+	//    If it's = 0xffff000 and bottom3NonZero != 0, then the whole value is >= p
 	//    If it's < 0xffff000, then the whole value is < p
-	n := out[3] - 0xffff000
+	n := 0xffff000 - out[3]
 	out3Equal := n
 	out3Equal |= out3Equal >> 16
 	out3Equal |= out3Equal >> 8
@@ -484,8 +486,8 @@
 	out3Equal |= out3Equal >> 1
 	out3Equal = ^uint32(int32(out3Equal<<31) >> 31)
 
-	// If out[3] > 0xffff000 then n's MSB will be zero.
-	out3GT := ^uint32(int32(n) >> 31)
+	// If out[3] > 0xffff000 then n's MSB will be one.
+	out3GT := uint32(int32(n) >> 31)
 
 	mask := top4AllOnes & ((out3Equal & bottom3NonZero) | out3GT)
 	out[0] -= 1 & mask
@@ -494,6 +496,15 @@
 	out[5] -= 0xfffffff & mask
 	out[6] -= 0xfffffff & mask
 	out[7] -= 0xfffffff & mask
+
+	// Do one final carry down, in case we made out[0] negative. One of
+	// out[0..3] needs to be positive and able to absorb the -1 or the value
+	// would have been < p, and the subtraction wouldn't have happened.
+	for i := 0; i < 3; i++ {
+		mask := uint32(int32(out[i]) >> 31)
+		out[i] += (1 << 28) & mask
+		out[i+1] -= 1 & mask
+	}
 }
 
 // Group element functions.
diff --git a/src/crypto/elliptic/p224_test.go b/src/crypto/elliptic/p224_test.go
index 8b4fa04..c3141b6 100644
--- a/src/crypto/elliptic/p224_test.go
+++ b/src/crypto/elliptic/p224_test.go
@@ -6,7 +6,11 @@
 
 import (
 	"math/big"
+	"math/bits"
+	"math/rand"
+	"reflect"
 	"testing"
+	"testing/quick"
 )
 
 var toFromBigTests = []string{
@@ -21,16 +25,16 @@
 	ret := new(big.Int)
 	tmp := new(big.Int)
 
-	for i := uint(0); i < 8; i++ {
+	for i := len(in) - 1; i >= 0; i-- {
+		ret.Lsh(ret, 28)
 		tmp.SetInt64(int64(in[i]))
-		tmp.Lsh(tmp, 28*i)
 		ret.Add(ret, tmp)
 	}
-	ret.Mod(ret, p224.P)
+	ret.Mod(ret, P224().Params().P)
 	return ret
 }
 
-func TestToFromBig(t *testing.T) {
+func TestP224ToFromBig(t *testing.T) {
 	for i, test := range toFromBigTests {
 		n, _ := new(big.Int).SetString(test, 16)
 		var x p224FieldElement
@@ -41,7 +45,270 @@
 		}
 		q := p224AlternativeToBig(&x)
 		if n.Cmp(q) != 0 {
-			t.Errorf("#%d: %x != %x (alternative)", i, n, m)
+			t.Errorf("#%d: %x != %x (alternative)", i, n, q)
 		}
 	}
 }
+
+// quickCheckConfig32 will make each quickcheck test run (32 * -quickchecks)
+// times. The default value of -quickchecks is 100.
+var quickCheckConfig32 = &quick.Config{MaxCountScale: 32}
+
+// weirdLimbs can be combined to generate a range of edge-case field elements.
+var weirdLimbs = [...]uint32{
+	0, 1, (1 << 29) - 1,
+	(1 << 12), (1 << 12) - 1,
+	(1 << 28), (1 << 28) - 1,
+}
+
+func generateLimb(rand *rand.Rand) uint32 {
+	const bottom29Bits = 0x1fffffff
+	n := rand.Intn(len(weirdLimbs) + 3)
+	switch n {
+	case len(weirdLimbs):
+		// Random value.
+		return uint32(rand.Int31n(1 << 29))
+	case len(weirdLimbs) + 1:
+		// Sum of two values.
+		k := generateLimb(rand) + generateLimb(rand)
+		return k & bottom29Bits
+	case len(weirdLimbs) + 2:
+		// Difference of two values.
+		k := generateLimb(rand) - generateLimb(rand)
+		return k & bottom29Bits
+	default:
+		return weirdLimbs[n]
+	}
+}
+
+func (p224FieldElement) Generate(rand *rand.Rand, size int) reflect.Value {
+	return reflect.ValueOf(p224FieldElement{
+		generateLimb(rand),
+		generateLimb(rand),
+		generateLimb(rand),
+		generateLimb(rand),
+		generateLimb(rand),
+		generateLimb(rand),
+		generateLimb(rand),
+		generateLimb(rand),
+	})
+}
+
+func isInBounds(x *p224FieldElement) bool {
+	return bits.Len32(x[0]) <= 29 &&
+		bits.Len32(x[1]) <= 29 &&
+		bits.Len32(x[2]) <= 29 &&
+		bits.Len32(x[3]) <= 29 &&
+		bits.Len32(x[4]) <= 29 &&
+		bits.Len32(x[5]) <= 29 &&
+		bits.Len32(x[6]) <= 29 &&
+		bits.Len32(x[7]) <= 29
+}
+
+func TestP224Mul(t *testing.T) {
+	mulMatchesBigInt := func(a, b, out p224FieldElement) bool {
+		var tmp p224LargeFieldElement
+		p224Mul(&out, &a, &b, &tmp)
+
+		exp := new(big.Int).Mul(p224AlternativeToBig(&a), p224AlternativeToBig(&b))
+		exp.Mod(exp, P224().Params().P)
+		got := p224AlternativeToBig(&out)
+		if exp.Cmp(got) != 0 || !isInBounds(&out) {
+			t.Logf("a = %x", a)
+			t.Logf("b = %x", b)
+			t.Logf("p224Mul(a, b) = %x = %v", out, got)
+			t.Logf("a * b = %v", exp)
+			return false
+		}
+
+		return true
+	}
+
+	a := p224FieldElement{0xfffffff, 0xfffffff, 0xf00ffff, 0x20f, 0x0, 0x0, 0x0, 0x0}
+	b := p224FieldElement{1, 0, 0, 0, 0, 0, 0, 0}
+	if !mulMatchesBigInt(a, b, p224FieldElement{}) {
+		t.Fail()
+	}
+
+	if err := quick.Check(mulMatchesBigInt, quickCheckConfig32); err != nil {
+		t.Error(err)
+	}
+}
+
+func TestP224Square(t *testing.T) {
+	squareMatchesBigInt := func(a, out p224FieldElement) bool {
+		var tmp p224LargeFieldElement
+		p224Square(&out, &a, &tmp)
+
+		exp := p224AlternativeToBig(&a)
+		exp.Mul(exp, exp)
+		exp.Mod(exp, P224().Params().P)
+		got := p224AlternativeToBig(&out)
+		if exp.Cmp(got) != 0 || !isInBounds(&out) {
+			t.Logf("a = %x", a)
+			t.Logf("p224Square(a, b) = %x = %v", out, got)
+			t.Logf("a * a = %v", exp)
+			return false
+		}
+
+		return true
+	}
+
+	if err := quick.Check(squareMatchesBigInt, quickCheckConfig32); err != nil {
+		t.Error(err)
+	}
+}
+
+func TestP224Add(t *testing.T) {
+	addMatchesBigInt := func(a, b, out p224FieldElement) bool {
+		p224Add(&out, &a, &b)
+
+		exp := new(big.Int).Add(p224AlternativeToBig(&a), p224AlternativeToBig(&b))
+		exp.Mod(exp, P224().Params().P)
+		got := p224AlternativeToBig(&out)
+		if exp.Cmp(got) != 0 {
+			t.Logf("a = %x", a)
+			t.Logf("b = %x", b)
+			t.Logf("p224Add(a, b) = %x = %v", out, got)
+			t.Logf("a + b = %v", exp)
+			return false
+		}
+
+		return true
+	}
+
+	if err := quick.Check(addMatchesBigInt, quickCheckConfig32); err != nil {
+		t.Error(err)
+	}
+}
+
+func TestP224Reduce(t *testing.T) {
+	reduceMatchesBigInt := func(a p224FieldElement) bool {
+		out := a
+		// TODO: generate higher values for functions like p224Reduce that are
+		// expected to work with higher input bounds.
+		p224Reduce(&out)
+
+		exp := p224AlternativeToBig(&a)
+		got := p224AlternativeToBig(&out)
+		if exp.Cmp(got) != 0 || !isInBounds(&out) {
+			t.Logf("a = %x = %v", a, exp)
+			t.Logf("p224Reduce(a) = %x = %v", out, got)
+			return false
+		}
+
+		return true
+	}
+
+	if err := quick.Check(reduceMatchesBigInt, quickCheckConfig32); err != nil {
+		t.Error(err)
+	}
+}
+
+func TestP224Contract(t *testing.T) {
+	contractMatchesBigInt := func(a, out p224FieldElement) bool {
+		p224Contract(&out, &a)
+
+		exp := p224AlternativeToBig(&a)
+		got := p224AlternativeToBig(&out)
+		if exp.Cmp(got) != 0 {
+			t.Logf("a = %x = %v", a, exp)
+			t.Logf("p224Contract(a) = %x = %v", out, got)
+			return false
+		}
+
+		// Check that out < P.
+		for i := range p224P {
+			k := 8 - i - 1
+			if out[k] > p224P[k] {
+				t.Logf("p224Contract(a) = %x", out)
+				return false
+			}
+			if out[k] < p224P[k] {
+				return true
+			}
+		}
+		t.Logf("p224Contract(a) = %x", out)
+		return false
+	}
+
+	if !contractMatchesBigInt(p224P, p224FieldElement{}) {
+		t.Error("p224Contract(p) is broken")
+	}
+	pMinus1 := p224FieldElement{0, 0, 0, 0xffff000, 0xfffffff, 0xfffffff, 0xfffffff, 0xfffffff}
+	if !contractMatchesBigInt(pMinus1, p224FieldElement{}) {
+		t.Error("p224Contract(p - 1) is broken")
+	}
+	// Check that we can handle input above p, but lowest limb zero.
+	a := p224FieldElement{0, 1, 0, 0xffff000, 0xfffffff, 0xfffffff, 0xfffffff, 0xfffffff}
+	if !contractMatchesBigInt(a, p224FieldElement{}) {
+		t.Error("p224Contract(p + 2²⁸) is broken")
+	}
+	// Check that we can handle input above p, but lowest three limbs zero.
+	b := p224FieldElement{0, 0, 0, 0xffff001, 0xfffffff, 0xfffffff, 0xfffffff, 0xfffffff}
+	if !contractMatchesBigInt(b, p224FieldElement{}) {
+		t.Error("p224Contract(p + 2⁸⁴) is broken")
+	}
+
+	if err := quick.Check(contractMatchesBigInt, quickCheckConfig32); err != nil {
+		t.Error(err)
+	}
+}
+
+func TestP224IsZero(t *testing.T) {
+	if got := p224IsZero(&p224FieldElement{}); got != 1 {
+		t.Errorf("p224IsZero(0) = %d, expected 1", got)
+	}
+	if got := p224IsZero((*p224FieldElement)(&p224P)); got != 1 {
+		t.Errorf("p224IsZero(p) = %d, expected 1", got)
+	}
+	if got := p224IsZero(&p224FieldElement{1}); got != 0 {
+		t.Errorf("p224IsZero(1) = %d, expected 0", got)
+	}
+
+	isZeroMatchesBigInt := func(a p224FieldElement) bool {
+		isZero := p224IsZero(&a)
+
+		big := p224AlternativeToBig(&a)
+		if big.Sign() == 0 && isZero != 1 {
+			return false
+		}
+		if big.Sign() != 0 && isZero != 0 {
+			return false
+		}
+		return true
+	}
+
+	if err := quick.Check(isZeroMatchesBigInt, quickCheckConfig32); err != nil {
+		t.Error(err)
+	}
+}
+
+func TestP224Invert(t *testing.T) {
+	var out p224FieldElement
+
+	p224Invert(&out, &p224FieldElement{})
+	if got := p224IsZero(&out); got != 1 {
+		t.Errorf("p224Invert(0) = %x, expected 0", out)
+	}
+
+	p224Invert(&out, (*p224FieldElement)(&p224P))
+	if got := p224IsZero(&out); got != 1 {
+		t.Errorf("p224Invert(p) = %x, expected 0", out)
+	}
+
+	p224Invert(&out, &p224FieldElement{1})
+	p224Contract(&out, &out)
+	if out != (p224FieldElement{1}) {
+		t.Errorf("p224Invert(1) = %x, expected 1", out)
+	}
+
+	var tmp p224LargeFieldElement
+	a := p224FieldElement{1, 2, 3, 4, 5, 6, 7, 8}
+	p224Invert(&out, &a)
+	p224Mul(&out, &out, &a, &tmp)
+	p224Contract(&out, &out)
+	if out != (p224FieldElement{1}) {
+		t.Errorf("p224Invert(a) * a = %x, expected 1", out)
+	}
+}