| // Copyright 2025 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package mldsa |
| |
| import ( |
| "crypto/internal/fips140/sha3" |
| "encoding/hex" |
| "fmt" |
| "math/big" |
| "testing" |
| ) |
| |
// interestingValue is a field element used as a fixed operand in the
// arithmetic tests, carried in both its plain and Montgomery representations.
type interestingValue struct {
	// v is the plain (non-Montgomery) value.
	v uint32
	// m is the Montgomery representation of v.
	m fieldElement
}
| |
| // q is large enough that we can't exhaustively test all q × q inputs, so when |
| // we have two inputs we test [0, q) on one side and a set of interesting |
| // values on the other side. |
| func interestingValues() []interestingValue { |
| if testing.Short() { |
| return []interestingValue{{v: q - 1, m: minusOne}} |
| } |
| var values []interestingValue |
| for _, v := range []uint32{ |
| 0, |
| 1, |
| 2, |
| 3, |
| q - 3, |
| q - 2, |
| q - 1, |
| q / 2, |
| (q + 1) / 2, |
| } { |
| m, _ := fieldToMontgomery(v) |
| values = append(values, interestingValue{v: v, m: m}) |
| // Also test values that have an interesting Montgomery representation. |
| values = append(values, interestingValue{ |
| v: fieldFromMontgomery(fieldElement(v)), m: fieldElement(v)}) |
| } |
| return values |
| } |
| |
| func TestToFromMontgomery(t *testing.T) { |
| for a := range uint32(q) { |
| m, err := fieldToMontgomery(a) |
| if err != nil { |
| t.Fatalf("fieldToMontgomery(%d) returned error: %v", a, err) |
| } |
| exp := fieldElement((uint64(a) * R) % q) |
| if m != exp { |
| t.Fatalf("fieldToMontgomery(%d) = %d, expected %d", a, m, exp) |
| } |
| got := fieldFromMontgomery(m) |
| if got != a { |
| t.Fatalf("fieldFromMontgomery(fieldToMontgomery(%d)) = %d, expected %d", a, got, a) |
| } |
| } |
| } |
| |
| func TestFieldAdd(t *testing.T) { |
| t.Parallel() |
| for _, a := range interestingValues() { |
| for b := range fieldElement(q) { |
| got := fieldAdd(a.m, b) |
| exp := (a.m + b) % q |
| if got != exp { |
| t.Fatalf("%d + %d = %d, expected %d", a, b, got, exp) |
| } |
| } |
| } |
| } |
| |
| func TestFieldSub(t *testing.T) { |
| t.Parallel() |
| for _, a := range interestingValues() { |
| for b := range fieldElement(q) { |
| got := fieldSub(a.m, b) |
| exp := (a.m + q - b) % q |
| if got != exp { |
| t.Fatalf("%d - %d = %d, expected %d", a, b, got, exp) |
| } |
| } |
| } |
| } |
| |
| func TestFieldSubToMontgomery(t *testing.T) { |
| t.Parallel() |
| for _, a := range interestingValues() { |
| for b := range uint32(q) { |
| got := fieldSubToMontgomery(a.v, b) |
| diff := (a.v + q - b) % q |
| exp := fieldElement((uint64(diff) * R) % q) |
| if got != exp { |
| t.Fatalf("fieldSubToMontgomery(%d, %d) = %d, expected %d", a.v, b, got, exp) |
| } |
| } |
| } |
| } |
| |
| func TestFieldReduceOnce(t *testing.T) { |
| t.Parallel() |
| for a := range uint32(2 * q) { |
| got := fieldReduceOnce(a) |
| var exp uint32 |
| if a < q { |
| exp = a |
| } else { |
| exp = a - q |
| } |
| if uint32(got) != exp { |
| t.Fatalf("fieldReduceOnce(%d) = %d, expected %d", a, got, exp) |
| } |
| } |
| } |
| |
| func TestFieldMul(t *testing.T) { |
| t.Parallel() |
| for _, a := range interestingValues() { |
| for b := range fieldElement(q) { |
| got := fieldFromMontgomery(fieldMontgomeryMul(a.m, b)) |
| exp := uint32((uint64(a.v) * uint64(fieldFromMontgomery(b))) % q) |
| if got != exp { |
| t.Fatalf("%d * %d = %d, expected %d", a, b, got, exp) |
| } |
| } |
| } |
| } |
| |
| func TestFieldToMontgomeryOverflow(t *testing.T) { |
| // fieldToMontgomery should reject inputs ≥ q. |
| inputs := []uint32{ |
| q, |
| q + 1, |
| q + 2, |
| 1<<23 - 1, |
| 1 << 23, |
| q + 1<<23, |
| q + 1<<31, |
| ^uint32(0), |
| } |
| for _, in := range inputs { |
| if _, err := fieldToMontgomery(in); err == nil { |
| t.Fatalf("fieldToMontgomery(%d) did not return an error", in) |
| } |
| } |
| } |
| |
| func TestFieldMulSub(t *testing.T) { |
| for _, a := range interestingValues() { |
| for _, b := range interestingValues() { |
| for _, c := range interestingValues() { |
| got := fieldFromMontgomery(fieldMontgomeryMulSub(a.m, b.m, c.m)) |
| exp := uint32((uint64(a.v) * (uint64(b.v) + q - uint64(c.v))) % q) |
| if got != exp { |
| t.Fatalf("%d * (%d - %d) = %d, expected %d", a.v, b.v, c.v, got, exp) |
| } |
| } |
| } |
| } |
| } |
| |
| func TestFieldAddMul(t *testing.T) { |
| for _, a := range interestingValues() { |
| for _, b := range interestingValues() { |
| for _, c := range interestingValues() { |
| for _, d := range interestingValues() { |
| got := fieldFromMontgomery(fieldMontgomeryAddMul(a.m, b.m, c.m, d.m)) |
| exp := uint32((uint64(a.v)*uint64(b.v) + uint64(c.v)*uint64(d.v)) % q) |
| if got != exp { |
| t.Fatalf("%d + %d * %d = %d, expected %d", a.v, b.v, c.v, got, exp) |
| } |
| } |
| } |
| } |
| } |
| } |
| |
// BitRev8 returns n with the order of its 8 bits reversed (bit 0 becomes
// bit 7, bit 1 becomes bit 6, and so on). TestZetas uses it to compute the
// expected exponent of each zetas entry.
func BitRev8(n uint8) uint8 {
	var rev uint8
	for i := 0; i < 8; i++ {
		// Shift the accumulated bits up and append n's lowest bit.
		rev = rev<<1 | n&1
		n >>= 1
	}
	return rev
}
| |
// CenteredMod returns the centered representative of x modulo m (written
// "x mod± m"). Residues above m/2 are shifted down by m, so the result lies in
// [-(m-1)/2, (m-1)/2] for odd m and in (-m/2, m/2] for even m.
func CenteredMod(x, m uint32) int32 {
	r := int32(x % m)
	if uint32(r) <= m/2 {
		return r
	}
	return r - int32(m)
}
| |
| func reduceModQ(x int32) uint32 { |
| x %= q |
| if x < 0 { |
| return uint32(x + q) |
| } |
| return uint32(x) |
| } |
| |
| func TestCenteredMod(t *testing.T) { |
| for x := range uint32(q * 2) { |
| got := CenteredMod(uint32(x), q) |
| if reduceModQ(got) != (x % q) { |
| t.Fatalf("CenteredMod(%d) = %d, which is not congruent to %d mod %d", x, got, x, q) |
| } |
| } |
| |
| for x := range uint32(q) { |
| r, _ := fieldToMontgomery(x) |
| got := fieldCenteredMod(r) |
| exp := CenteredMod(x, q) |
| if got != exp { |
| t.Fatalf("fieldCenteredMod(%d) = %d, expected %d", x, got, exp) |
| } |
| } |
| } |
| |
| func TestInfinityNorm(t *testing.T) { |
| for x := range uint32(q) { |
| r, _ := fieldToMontgomery(x) |
| got := fieldInfinityNorm(r) |
| exp := CenteredMod(x, q) |
| if exp < 0 { |
| exp = -exp |
| } |
| if got != uint32(exp) { |
| t.Fatalf("fieldInfinityNorm(%d) = %d, expected %d", x, got, exp) |
| } |
| } |
| } |
| |
| func TestConstants(t *testing.T) { |
| if fieldFromMontgomery(one) != 1 { |
| t.Errorf("one constant incorrect") |
| } |
| if fieldFromMontgomery(minusOne) != q-1 { |
| t.Errorf("minusOne constant incorrect") |
| } |
| if fieldInfinityNorm(one) != 1 { |
| t.Errorf("one infinity norm incorrect") |
| } |
| if fieldInfinityNorm(minusOne) != 1 { |
| t.Errorf("minusOne infinity norm incorrect") |
| } |
| |
| if PublicKeySize44 != pubKeySize(params44) { |
| t.Errorf("PublicKeySize44 constant incorrect") |
| } |
| if PublicKeySize65 != pubKeySize(params65) { |
| t.Errorf("PublicKeySize65 constant incorrect") |
| } |
| if PublicKeySize87 != pubKeySize(params87) { |
| t.Errorf("PublicKeySize87 constant incorrect") |
| } |
| if SignatureSize44 != sigSize(params44) { |
| t.Errorf("SignatureSize44 constant incorrect") |
| } |
| if SignatureSize65 != sigSize(params65) { |
| t.Errorf("SignatureSize65 constant incorrect") |
| } |
| if SignatureSize87 != sigSize(params87) { |
| t.Errorf("SignatureSize87 constant incorrect") |
| } |
| } |
| |
| func TestPower2Round(t *testing.T) { |
| t.Parallel() |
| for x := range uint32(q) { |
| rr, _ := fieldToMontgomery(x) |
| t1, t0 := power2Round(rr) |
| |
| hi, err := fieldToMontgomery(uint32(t1) << 13) |
| if err != nil { |
| t.Fatalf("power2Round(%d): failed to convert high part to Montgomery: %v", x, err) |
| } |
| if r := fieldFromMontgomery(fieldAdd(hi, t0)); r != x { |
| t.Fatalf("power2Round(%d) = (%d, %d), which reconstructs to %d, expected %d", x, t1, t0, r, x) |
| } |
| } |
| } |
| |
| func SpecDecompose(rr fieldElement, p parameters) (R1 uint32, R0 int32) { |
| r := fieldFromMontgomery(rr) |
| if (q-1)%p.γ2 != 0 { |
| panic("mldsa: internal error: unsupported denγ2") |
| } |
| γ2 := (q - 1) / uint32(p.γ2) |
| r0 := CenteredMod(r, 2*γ2) |
| diff := int32(r) - r0 |
| if diff == q-1 { |
| r0 = r0 - 1 |
| return 0, r0 |
| } else { |
| if diff < 0 || uint32(diff)%γ2 != 0 { |
| panic("mldsa: internal error: invalid decomposition") |
| } |
| r1 := uint32(diff) / (2 * γ2) |
| return r1, r0 |
| } |
| } |
| |
| func TestDecompose(t *testing.T) { |
| t.Run("ML-DSA-44", func(t *testing.T) { |
| testDecompose(t, params44) |
| }) |
| t.Run("ML-DSA-65,87", func(t *testing.T) { |
| testDecompose(t, params65) |
| }) |
| } |
| |
| func testDecompose(t *testing.T, p parameters) { |
| t.Parallel() |
| for x := range uint32(q) { |
| rr, _ := fieldToMontgomery(x) |
| r1, r0 := SpecDecompose(rr, p) |
| |
| // Check that SpecDecompose is correct. |
| // r ≡ r1 * (2 * γ2) + r0 mod q |
| γ2 := (q - 1) / uint32(p.γ2) |
| reconstructed := reduceModQ(int32(r1*2*γ2) + r0) |
| if reconstructed != x { |
| t.Fatalf("SpecDecompose(%d) = (%d, %d), which reconstructs to %d, expected %d", x, r1, r0, reconstructed, x) |
| } |
| |
| var gotR1 byte |
| var gotR0 int32 |
| switch p.γ2 { |
| case 88: |
| gotR1, gotR0 = decompose88(rr) |
| if gotR1 > 43 { |
| t.Fatalf("decompose88(%d) returned r1 = %d, which is out of range", x, gotR1) |
| } |
| case 32: |
| gotR1, gotR0 = decompose32(rr) |
| if gotR1 > 15 { |
| t.Fatalf("decompose32(%d) returned r1 = %d, which is out of range", x, gotR1) |
| } |
| default: |
| t.Fatalf("unsupported denγ2: %d", p.γ2) |
| } |
| if uint32(gotR1) != r1 { |
| t.Fatalf("highBits(%d) = %d, expected %d", x, gotR1, r1) |
| } |
| if gotR0 != r0 { |
| t.Fatalf("lowBits(%d) = %d, expected %d", x, gotR0, r0) |
| } |
| } |
| } |
| |
| func TestZetas(t *testing.T) { |
| ζ := big.NewInt(1753) |
| q := big.NewInt(q) |
| for k, zeta := range zetas { |
| // ζ^BitRev₈(k) mod q |
| exp := new(big.Int).Exp(ζ, big.NewInt(int64(BitRev8(uint8(k)))), q) |
| got := fieldFromMontgomery(zeta) |
| if big.NewInt(int64(got)).Cmp(exp) != 0 { |
| t.Errorf("zetas[%d] = %v, expected %v", k, got, exp) |
| } |
| } |
| } |
| |
// TestAccumulated is a whole-field regression test. For every element r of
// ℤq, from 0 to q-1, it writes the following 12 values into a single SHAKE128
// hash. Each value is an ASCII decimal (with a leading - sign if negative)
// followed by a newline. The test then compares the digest with a known-good
// value:
//
// - r mod± q
// - ‖r‖∞ = |r mod± q|
// - r1, r0 = Power2Round(r)
//
// For ML-DSA-44 (γ₂ = (q - 1) / 88):
// - HighBits(r) = UseHint(0, r)
// - UseHint(1, r)
// - LowBits(r)
// - ‖LowBits(r)‖∞ = |LowBits(r)|
//
// For ML-DSA-65 and ML-DSA-87 (γ₂ = (q - 1) / 32):
// - HighBits(r) = UseHint(0, r)
// - UseHint(1, r)
// - LowBits(r)
// - ‖LowBits(r)‖∞ = |LowBits(r)|
//
// Note that HighBits(r), LowBits(r) = Decompose(r).
//
// The digest covers every value in a fixed order and format. Any change to
// what is written, including the order of the Fprintf calls, changes the
// expected hash. The loop also asserts directly that highBitsNN and
// useHintNN(r, 0) agree with the high part returned by decomposeNN.
func TestAccumulated(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping accumulated test in short mode")
	}

	// Every value below is streamed into this one SHAKE128 instance.
	o := sha3.NewShake128()
	for x := range uint32(q) {
		r, _ := fieldToMontgomery(x)
		fmt.Fprintf(o, "%d\n", fieldCenteredMod(r))
		fmt.Fprintf(o, "%d\n", fieldInfinityNorm(r))

		// Power2Round. The low part is written as the uint32 returned by
		// fieldFromMontgomery(lo), not as a signed value. It is presumably
		// the representative in [0, q); confirm against the CCTV definition.
		hi, lo := power2Round(r)
		fmt.Fprintf(o, "%d\n", hi)
		fmt.Fprintf(o, "%d\n", fieldFromMontgomery(lo))

		// γ₂ = (q - 1) / 88 (ML-DSA-44). Besides hashing, cross-check that
		// highBits88 and useHint88 with a zero hint agree with decompose88.
		r1, r0 := decompose88(r)
		if r1x := highBits88(fieldFromMontgomery(r)); r1x != r1 {
			t.Fatalf("highBits88(%d) = %d, expected %d", x, r1x, r1)
		}
		if r1h0 := useHint88(r, 0); r1h0 != r1 {
			t.Fatalf("useHint88(%d, 0) = %d, expected %d", x, r1h0, r1)
		}

		fmt.Fprintf(o, "%d\n", r1)
		fmt.Fprintf(o, "%d\n", useHint88(r, 1))
		fmt.Fprintf(o, "%d\n", r0)
		fmt.Fprintf(o, "%d\n", constantTimeAbs(r0))

		// γ₂ = (q - 1) / 32 (ML-DSA-65 and ML-DSA-87): the same checks and
		// values as above, using the 32-variant functions.
		r1, r0 = decompose32(r)
		if r1x := highBits32(fieldFromMontgomery(r)); r1x != r1 {
			t.Fatalf("highBits32(%d) = %d, expected %d", x, r1x, r1)
		}
		if r1h0 := useHint32(r, 0); r1h0 != r1 {
			t.Fatalf("useHint32(%d, 0) = %d, expected %d", x, r1h0, r1)
		}

		fmt.Fprintf(o, "%d\n", r1)
		fmt.Fprintf(o, "%d\n", useHint32(r, 1))
		fmt.Fprintf(o, "%d\n", r0)
		fmt.Fprintf(o, "%d\n", constantTimeAbs(r0))
	}

	// The expected value is documented at https://c2sp.org/CCTV/ML-DSA, and
	// tested against https://github.com/FiloSottile/mldsa-py.
	expected := "f930663417278156ab05d940294a77210a809c924d8ab63ec72f4526247602c7"
	if got := hex.EncodeToString(o.Sum(nil)); got != expected {
		t.Errorf("got %s, expected %s", got, expected)
	}
}