blob: 1bbdb667d2eddaafe2b5b729f26c8b944df7968b [file] [log] [blame]
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ecdsa
import (
"bytes"
"crypto/internal/fips140/bigmod"
"crypto/rand"
"io"
"testing"
)
// TestRandomPoint runs the randomPoint rejection-sampling checks as one
// subtest per supported NIST curve.
func TestRandomPoint(t *testing.T) {
	subtests := []struct {
		name string
		run  func(*testing.T)
	}{
		{"P-224", func(t *testing.T) { testRandomPoint(t, P224()) }},
		{"P-256", func(t *testing.T) { testRandomPoint(t, P256()) }},
		{"P-384", func(t *testing.T) { testRandomPoint(t, P384()) }},
		{"P-521", func(t *testing.T) { testRandomPoint(t, P521()) }},
	}
	for _, st := range subtests {
		t.Run(st.name, st.run)
	}
}
// testRandomPoint checks randomPoint's rejection sampling on curve c.
//
// The rejection-sampling loop is observed through the
// testingOnlyRejectionSamplingLooped hook, which is reset when the test ends.
// randomPoint is first fed a prefix of all-ones bytes, then a prefix of
// all-zero bytes; both must be rejected at least once. Except on P-256, it is
// then fed plain random bytes, which must not be rejected at all.
func testRandomPoint[P Point[P]](t *testing.T, c *Curve[P]) {
	t.Cleanup(func() { testingOnlyRejectionSamplingLooped = nil })
	var loopCount int
	testingOnlyRejectionSamplingLooped = func() { loopCount++ }

	// A sequence of all ones will generate 2^N-1, which should be rejected.
	// (Unless, for example, we are masking too many bits.)
	ones := io.MultiReader(bytes.NewReader(bytes.Repeat([]byte{0xff}, 100)), rand.Reader)
	checkRandomPointFrom(t, c, ones)
	if loopCount == 0 {
		t.Error("overflow was not rejected")
	}
	loopCount = 0

	// A sequence of all zeroes will generate zero, which should be rejected.
	zeroes := io.MultiReader(bytes.NewReader(bytes.Repeat([]byte{0}, 100)), rand.Reader)
	checkRandomPointFrom(t, c, zeroes)
	if loopCount == 0 {
		t.Error("zero was not rejected")
	}
	loopCount = 0

	// P-256 has a 2⁻³² chance of randomly hitting a rejection. For P-224 it's
	// 2⁻¹¹², for P-384 it's 2⁻¹⁹⁴, and for P-521 it's 2⁻²⁶², so if we hit in
	// tests, something is horribly wrong. (For example, we are masking the
	// wrong bits.)
	if c.curve == p256 {
		return
	}
	checkRandomPointFrom(t, c, rand.Reader)
	if loopCount > 0 {
		t.Error("unexpected rejection")
	}
}

// checkRandomPointFrom calls randomPoint on c with candidate bytes drawn from
// r. It fails the test if randomPoint returns an error, and reports an error
// if the returned scalar is zero or the returned point's encoding does not
// start with 4.
func checkRandomPointFrom[P Point[P]](t *testing.T, c *Curve[P], r io.Reader) {
	t.Helper()
	k, p, err := randomPoint(c, func(b []byte) error {
		// io.ReadFull, not r.Read: an io.Reader may return fewer bytes than
		// requested (io.MultiReader does so at the boundary between its
		// readers), which would leave the tail of b unfilled.
		_, err := io.ReadFull(r, b)
		return err
	})
	if err != nil {
		t.Fatal(err)
	}
	switch {
	case k.IsZero() == 1:
		t.Error("k is zero")
	case p.Bytes()[0] != 4:
		t.Error("p is infinity")
	}
}
// TestHashToNat runs the hashToNat length sweep as one subtest per supported
// NIST curve.
func TestHashToNat(t *testing.T) {
	subtests := []struct {
		name string
		run  func(*testing.T)
	}{
		{"P-224", func(t *testing.T) { testHashToNat(t, P224()) }},
		{"P-256", func(t *testing.T) { testHashToNat(t, P256()) }},
		{"P-384", func(t *testing.T) { testHashToNat(t, P384()) }},
		{"P-521", func(t *testing.T) { testHashToNat(t, P521()) }},
	}
	for _, st := range subtests {
		t.Run(st.name, st.run)
	}
}
func testHashToNat[P Point[P]](t *testing.T, c *Curve[P]) {
for l := 0; l < 600; l++ {
h := bytes.Repeat([]byte{0xff}, l)
hashToNat(c, bigmod.NewNat(), h)
}
}