blob: 913e0165b0162b781cd11a8b4708d4529b0f4b36 [file] [edit]
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mldsa
import (
"crypto/internal/fips140/sha3"
"encoding/hex"
"fmt"
"math/big"
"testing"
)
// interestingValue pairs a field value with its Montgomery representation.
// Instances are produced by interestingValues and consumed by the two-input
// field arithmetic tests.
type interestingValue struct {
// v is the plain (non-Montgomery) value.
v uint32
// m is the Montgomery representation corresponding to v.
m fieldElement
}
// interestingValues returns the fixed operand set used by the multi-input
// field tests. q is too large to exhaustively test all q × q input pairs, so
// those tests sweep [0, q) on one side and draw the other side from this set.
//
// Every seed contributes two entries: one where the seed is the plain value,
// and one where the seed is the Montgomery representation. In short mode only
// the pair (q - 1, minusOne) is returned.
func interestingValues() []interestingValue {
	if testing.Short() {
		return []interestingValue{{v: q - 1, m: minusOne}}
	}
	seeds := [...]uint32{0, 1, 2, 3, q - 3, q - 2, q - 1, q / 2, (q + 1) / 2}
	values := make([]interestingValue, 0, 2*len(seeds))
	for _, seed := range seeds {
		// The seed taken as a plain value. All seeds are below q, so the
		// conversion is not expected to fail and the error is discarded.
		mont, _ := fieldToMontgomery(seed)
		asPlain := interestingValue{v: seed, m: mont}
		// The seed taken as a Montgomery representation, so that values with
		// an interesting Montgomery form are covered too.
		raw := fieldElement(seed)
		asMont := interestingValue{v: fieldFromMontgomery(raw), m: raw}
		values = append(values, asPlain, asMont)
	}
	return values
}
// TestToFromMontgomery checks, for every a in [0, q), that fieldToMontgomery
// succeeds and yields a·R mod q, and that fieldFromMontgomery maps the result
// back to a.
func TestToFromMontgomery(t *testing.T) {
	for a := uint32(0); a < q; a++ {
		m, err := fieldToMontgomery(a)
		if err != nil {
			t.Fatalf("fieldToMontgomery(%d) returned error: %v", a, err)
		}
		// Reference conversion computed in 64 bits to avoid overflow.
		if exp := fieldElement((uint64(a) * R) % q); m != exp {
			t.Fatalf("fieldToMontgomery(%d) = %d, expected %d", a, m, exp)
		}
		if got := fieldFromMontgomery(m); got != a {
			t.Fatalf("fieldFromMontgomery(fieldToMontgomery(%d)) = %d, expected %d", a, got, a)
		}
	}
}
// TestFieldAdd checks fieldAdd against plain modular addition, pairing each
// interesting left operand with every right operand in [0, q).
func TestFieldAdd(t *testing.T) {
	t.Parallel()
	for _, a := range interestingValues() {
		for b := range fieldElement(q) {
			got := fieldAdd(a.m, b)
			exp := (a.m + b) % q
			if got != exp {
				// Report a.m, the operand actually passed to fieldAdd, rather
				// than the whole interestingValue struct.
				t.Fatalf("%d + %d = %d, expected %d", a.m, b, got, exp)
			}
		}
	}
}
// TestFieldSub checks fieldSub against plain modular subtraction, pairing each
// interesting left operand with every right operand in [0, q).
func TestFieldSub(t *testing.T) {
	t.Parallel()
	for _, a := range interestingValues() {
		for b := range fieldElement(q) {
			got := fieldSub(a.m, b)
			// Adding q before subtracting keeps the reference non-negative.
			exp := (a.m + q - b) % q
			if got != exp {
				// Report a.m, the operand actually passed to fieldSub, rather
				// than the whole interestingValue struct.
				t.Fatalf("%d - %d = %d, expected %d", a.m, b, got, exp)
			}
		}
	}
}
// TestFieldSubToMontgomery checks that fieldSubToMontgomery(a, b) equals the
// Montgomery form of (a - b) mod q, for each interesting a and every b in
// [0, q).
func TestFieldSubToMontgomery(t *testing.T) {
	t.Parallel()
	for _, iv := range interestingValues() {
		a := iv.v
		for b := uint32(0); b < q; b++ {
			// Reference: subtract modulo q, then convert using 64-bit math.
			diff := (a + q - b) % q
			exp := fieldElement((uint64(diff) * R) % q)
			if got := fieldSubToMontgomery(a, b); got != exp {
				t.Fatalf("fieldSubToMontgomery(%d, %d) = %d, expected %d", a, b, got, exp)
			}
		}
	}
}
// TestFieldReduceOnce checks fieldReduceOnce on every input in [0, 2q): values
// below q must be returned unchanged, and the rest must have q subtracted.
func TestFieldReduceOnce(t *testing.T) {
	t.Parallel()
	for a := uint32(0); a < 2*q; a++ {
		exp := a
		if a >= q {
			exp -= q
		}
		if got := fieldReduceOnce(a); uint32(got) != exp {
			t.Fatalf("fieldReduceOnce(%d) = %d, expected %d", a, got, exp)
		}
	}
}
// TestFieldMul checks fieldMontgomeryMul: the product of two Montgomery-form
// operands, converted back with fieldFromMontgomery, must equal the product of
// the corresponding plain values modulo q. Each interesting left operand is
// paired with every right operand in [0, q).
func TestFieldMul(t *testing.T) {
	t.Parallel()
	for _, a := range interestingValues() {
		for b := range fieldElement(q) {
			bv := fieldFromMontgomery(b)
			got := fieldFromMontgomery(fieldMontgomeryMul(a.m, b))
			exp := uint32((uint64(a.v) * uint64(bv)) % q)
			if got != exp {
				// got and exp are plain values, so report the plain operands
				// rather than the struct a and the Montgomery-form b.
				t.Fatalf("%d * %d = %d, expected %d", a.v, bv, got, exp)
			}
		}
	}
}
// TestFieldToMontgomeryOverflow checks that fieldToMontgomery returns an error
// for a selection of inputs that are all ≥ q.
func TestFieldToMontgomeryOverflow(t *testing.T) {
	outOfRange := [...]uint32{
		q,
		q + 1,
		q + 2,
		(1 << 23) - 1,
		1 << 23,
		q + (1 << 23),
		q + (1 << 31),
		^uint32(0),
	}
	for _, in := range outOfRange {
		_, err := fieldToMontgomery(in)
		if err != nil {
			continue
		}
		t.Fatalf("fieldToMontgomery(%d) did not return an error", in)
	}
}
// TestFieldMulSub checks fieldMontgomeryMulSub over every triple of
// interesting values: the result, converted out of Montgomery form, must equal
// a * (b - c) mod q computed on the plain values.
func TestFieldMulSub(t *testing.T) {
	values := interestingValues()
	for _, a := range values {
		for _, b := range values {
			for _, c := range values {
				// Adding q keeps b - c non-negative; the product fits in 64 bits.
				exp := uint32((uint64(a.v) * (uint64(b.v) + q - uint64(c.v))) % q)
				got := fieldFromMontgomery(fieldMontgomeryMulSub(a.m, b.m, c.m))
				if got == exp {
					continue
				}
				t.Fatalf("%d * (%d - %d) = %d, expected %d", a.v, b.v, c.v, got, exp)
			}
		}
	}
}
// TestFieldAddMul checks fieldMontgomeryAddMul over every 4-tuple of
// interesting values: the result, converted out of Montgomery form, must equal
// a*b + c*d mod q computed on the plain values.
func TestFieldAddMul(t *testing.T) {
	values := interestingValues()
	for _, a := range values {
		for _, b := range values {
			for _, c := range values {
				for _, d := range values {
					got := fieldFromMontgomery(fieldMontgomeryAddMul(a.m, b.m, c.m, d.m))
					exp := uint32((uint64(a.v)*uint64(b.v) + uint64(c.v)*uint64(d.v)) % q)
					if got != exp {
						// Print all four operands in the shape of the actual
						// operation, a*b + c*d.
						t.Fatalf("%d * %d + %d * %d = %d, expected %d", a.v, b.v, c.v, d.v, got, exp)
					}
				}
			}
		}
	}
}
// BitRev8 returns n with the order of its eight bits reversed, so that bit 0
// becomes bit 7, bit 1 becomes bit 6, and so on.
func BitRev8(n uint8) uint8 {
	var rev uint8
	for i := 0; i < 8; i++ {
		// Shift the lowest remaining bit of n into the bottom of rev.
		rev = rev<<1 | n&1
		n >>= 1
	}
	return rev
}
// CenteredMod returns the representative of x modulo m that lies in the
// centered range: x mod m when that is at most m/2 (integer division), and
// (x mod m) - m otherwise.
func CenteredMod(x, m uint32) int32 {
	rem := x % m
	if rem <= m/2 {
		return int32(rem)
	}
	return int32(rem) - int32(m)
}
// reduceModQ maps a possibly negative x to its representative in [0, q).
func reduceModQ(x int32) uint32 {
	rem := x % q
	if rem >= 0 {
		return uint32(rem)
	}
	// Go's % keeps the sign of the dividend, so shift negatives up by q.
	return uint32(rem + q)
}
// TestCenteredMod first validates the CenteredMod reference helper on
// [0, 2q), checking that its result is congruent to the input modulo q. It
// then checks fieldCenteredMod against that reference for every value in
// [0, q).
func TestCenteredMod(t *testing.T) {
	for x := uint32(0); x < 2*q; x++ {
		got := CenteredMod(x, q)
		if reduceModQ(got) == x%q {
			continue
		}
		t.Fatalf("CenteredMod(%d) = %d, which is not congruent to %d mod %d", x, got, x, q)
	}
	for x := uint32(0); x < q; x++ {
		r, _ := fieldToMontgomery(x)
		exp := CenteredMod(x, q)
		if got := fieldCenteredMod(r); got != exp {
			t.Fatalf("fieldCenteredMod(%d) = %d, expected %d", x, got, exp)
		}
	}
}
// TestInfinityNorm checks, for every x in [0, q), that fieldInfinityNorm
// equals the absolute value of the centered representative of x modulo q.
func TestInfinityNorm(t *testing.T) {
	for x := uint32(0); x < q; x++ {
		r, _ := fieldToMontgomery(x)
		centered := CenteredMod(x, q)
		exp := max(centered, -centered)
		if got := fieldInfinityNorm(r); got != uint32(exp) {
			t.Fatalf("fieldInfinityNorm(%d) = %d, expected %d", x, got, exp)
		}
	}
}
// TestConstants checks the precomputed package constants: the Montgomery-form
// one and minusOne (value and infinity norm), and the exported public key and
// signature sizes against the sizes derived from each parameter set.
func TestConstants(t *testing.T) {
	// check records a non-fatal failure with msg when ok is false.
	check := func(ok bool, msg string) {
		t.Helper()
		if !ok {
			t.Error(msg)
		}
	}

	check(fieldFromMontgomery(one) == 1, "one constant incorrect")
	check(fieldFromMontgomery(minusOne) == q-1, "minusOne constant incorrect")
	check(fieldInfinityNorm(one) == 1, "one infinity norm incorrect")
	check(fieldInfinityNorm(minusOne) == 1, "minusOne infinity norm incorrect")

	check(PublicKeySize44 == pubKeySize(params44), "PublicKeySize44 constant incorrect")
	check(PublicKeySize65 == pubKeySize(params65), "PublicKeySize65 constant incorrect")
	check(PublicKeySize87 == pubKeySize(params87), "PublicKeySize87 constant incorrect")

	check(SignatureSize44 == sigSize(params44), "SignatureSize44 constant incorrect")
	check(SignatureSize65 == sigSize(params65), "SignatureSize65 constant incorrect")
	check(SignatureSize87 == sigSize(params87), "SignatureSize87 constant incorrect")
}
// TestPower2Round checks, for every x in [0, q), that the two parts returned
// by power2Round recombine to x: the high part shifted left by 13 bits plus
// the low part must equal x modulo q.
func TestPower2Round(t *testing.T) {
	t.Parallel()
	for x := uint32(0); x < q; x++ {
		rr, _ := fieldToMontgomery(x)
		t1, t0 := power2Round(rr)
		// The high part is scaled by 2^13 and converted so that it can be
		// added to the low part with fieldAdd.
		hi, err := fieldToMontgomery(uint32(t1) << 13)
		if err != nil {
			t.Fatalf("power2Round(%d): failed to convert high part to Montgomery: %v", x, err)
		}
		r := fieldFromMontgomery(fieldAdd(hi, t0))
		if r == x {
			continue
		}
		t.Fatalf("power2Round(%d) = (%d, %d), which reconstructs to %d, expected %d", x, t1, t0, r, x)
	}
}
// SpecDecompose is a direct, unoptimized reference for splitting a field
// element into high and low parts, used to cross-check decompose88 and
// decompose32.
//
// rr is in Montgomery form. p.γ2 holds the denominator d, so that
// γ2 = (q - 1) / d. The result satisfies r ≡ R1·(2γ2) + R0 mod q, where R0 is
// the centered remainder of r modulo 2γ2, except in the wrap-around case
// r - R0 = q - 1, where R1 is 0 and R0 is decremented by one.
//
// It panics if d does not divide q - 1, or if an internal consistency check
// fails.
func SpecDecompose(rr fieldElement, p parameters) (R1 uint32, R0 int32) {
	r := fieldFromMontgomery(rr)
	if (q-1)%p.γ2 != 0 {
		panic("mldsa: internal error: unsupported denγ2")
	}
	γ2 := (q - 1) / uint32(p.γ2)
	r0 := CenteredMod(r, 2*γ2)
	diff := int32(r) - r0
	if diff == q-1 {
		return 0, r0 - 1
	}
	// r0 is congruent to r modulo 2γ2, so diff must be a non-negative multiple
	// of 2γ2. Check against the same divisor used below, so that the division
	// is known to be exact.
	if diff < 0 || uint32(diff)%(2*γ2) != 0 {
		panic("mldsa: internal error: invalid decomposition")
	}
	return uint32(diff) / (2 * γ2), r0
}
func TestDecompose(t *testing.T) {
t.Run("ML-DSA-44", func(t *testing.T) {
testDecompose(t, params44)
})
t.Run("ML-DSA-65,87", func(t *testing.T) {
testDecompose(t, params65)
})
}
// testDecompose checks the decompose implementation selected by p.γ2 against
// SpecDecompose for every x in [0, q). For each input it first confirms that
// the reference output reconstructs x, then that the optimized high part is in
// range and that both parts match the reference.
func testDecompose(t *testing.T, p parameters) {
	t.Parallel()
	// γ2 depends only on the parameter set, so compute it once.
	γ2 := (q - 1) / uint32(p.γ2)
	for x := uint32(0); x < q; x++ {
		rr, _ := fieldToMontgomery(x)
		wantR1, wantR0 := SpecDecompose(rr, p)
		// Check that SpecDecompose is correct.
		// r ≡ r1 * (2 * γ2) + r0 mod q
		if reconstructed := reduceModQ(int32(wantR1*2*γ2) + wantR0); reconstructed != x {
			t.Fatalf("SpecDecompose(%d) = (%d, %d), which reconstructs to %d, expected %d", x, wantR1, wantR0, reconstructed, x)
		}

		var (
			gotR1 byte
			gotR0 int32
		)
		switch p.γ2 {
		case 88:
			gotR1, gotR0 = decompose88(rr)
			if gotR1 > 43 {
				t.Fatalf("decompose88(%d) returned r1 = %d, which is out of range", x, gotR1)
			}
		case 32:
			gotR1, gotR0 = decompose32(rr)
			if gotR1 > 15 {
				t.Fatalf("decompose32(%d) returned r1 = %d, which is out of range", x, gotR1)
			}
		default:
			t.Fatalf("unsupported denγ2: %d", p.γ2)
		}

		if uint32(gotR1) != wantR1 {
			t.Fatalf("highBits(%d) = %d, expected %d", x, gotR1, wantR1)
		}
		if gotR0 != wantR0 {
			t.Fatalf("lowBits(%d) = %d, expected %d", x, gotR0, wantR0)
		}
	}
}
// TestZetas recomputes every entry of the zetas table with math/big: entry k,
// converted out of Montgomery form, must equal 1753^BitRev8(k) mod q.
func TestZetas(t *testing.T) {
	root := big.NewInt(1753)
	modulus := big.NewInt(q)
	for k, zeta := range zetas {
		// ζ^BitRev₈(k) mod q
		exponent := big.NewInt(int64(BitRev8(uint8(k))))
		exp := new(big.Int).Exp(root, exponent, modulus)
		got := fieldFromMontgomery(zeta)
		if exp.Cmp(big.NewInt(int64(got))) != 0 {
			t.Errorf("zetas[%d] = %v, expected %v", k, got, exp)
		}
	}
}
// TestAccumulated hashes a transcript of 12 values for every element r of ℤq,
// from 0 to q-1, and compares the digest with a known answer. Each value is
// written as an ASCII decimal, with a leading - sign if negative, followed by
// a newline. The values, in order, are:
//
//   - r mod± q
//   - ‖r‖∞ = |r mod± q|
//   - r1, r0 = Power2Round(r)
//
// For ML-DSA-44 (γ₂ = (q - 1) / 88):
//   - HighBits(r) = UseHint(0, r)
//   - UseHint(1, r)
//   - LowBits(r)
//   - ‖LowBits(r)‖∞ = |LowBits(r)|
//
// For ML-DSA-65 and ML-DSA-87 (γ₂ = (q - 1) / 32):
//   - HighBits(r) = UseHint(0, r)
//   - UseHint(1, r)
//   - LowBits(r)
//   - ‖LowBits(r)‖∞ = |LowBits(r)|
//
// Note that HighBits(r), LowBits(r) = Decompose(r). Along the way the test
// also checks that highBits and useHint with a zero hint agree with the high
// part returned by decompose.
func TestAccumulated(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping accumulated test in short mode")
	}
	digest := sha3.NewShake128()
	// emit appends one value to the transcript. The order of emit calls
	// defines the expected hash and must not change.
	emit := func(v any) {
		fmt.Fprintf(digest, "%d\n", v)
	}
	for x := uint32(0); x < q; x++ {
		r, _ := fieldToMontgomery(x)
		emit(fieldCenteredMod(r))
		emit(fieldInfinityNorm(r))
		t1, t0 := power2Round(r)
		emit(t1)
		emit(fieldFromMontgomery(t0))

		// γ₂ = (q - 1) / 88
		hi88, lo88 := decompose88(r)
		if alt := highBits88(fieldFromMontgomery(r)); alt != hi88 {
			t.Fatalf("highBits88(%d) = %d, expected %d", x, alt, hi88)
		}
		if alt := useHint88(r, 0); alt != hi88 {
			t.Fatalf("useHint88(%d, 0) = %d, expected %d", x, alt, hi88)
		}
		emit(hi88)
		emit(useHint88(r, 1))
		emit(lo88)
		emit(constantTimeAbs(lo88))

		// γ₂ = (q - 1) / 32
		hi32, lo32 := decompose32(r)
		if alt := highBits32(fieldFromMontgomery(r)); alt != hi32 {
			t.Fatalf("highBits32(%d) = %d, expected %d", x, alt, hi32)
		}
		if alt := useHint32(r, 0); alt != hi32 {
			t.Fatalf("useHint32(%d, 0) = %d, expected %d", x, alt, hi32)
		}
		emit(hi32)
		emit(useHint32(r, 1))
		emit(lo32)
		emit(constantTimeAbs(lo32))
	}
	// The expected value is documented at https://c2sp.org/CCTV/ML-DSA, and
	// tested against https://github.com/FiloSottile/mldsa-py.
	const expected = "f930663417278156ab05d940294a77210a809c924d8ab63ec72f4526247602c7"
	got := hex.EncodeToString(digest.Sum(nil))
	if got != expected {
		t.Errorf("got %s, expected %s", got, expected)
	}
}