blob: eff3371871345d9c1140ccdcbb3c924613fa6648 [file] [log] [blame] [edit]
// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mldsa
import (
"crypto/internal/fips140/drbg"
"errors"
"math/bits"
)
// FIPS 204 defines a needless semi-expanded format for private keys. This is
// not a good format for key storage and exchange, because it is large and
// requires careful parsing to reject malformed keys. Seeds instead are just 32
// bytes, are always valid, and always expand to valid keys in memory. It is
// *also* a poor in-memory format, because it defers computing the NTT of s1,
// s2, and t0 and the expansion of A until signing time, which is inefficient.
// For a hot second, it looked like we could have all agreed to only use seeds,
// but unfortunately OpenSSL and BouncyCastle lobbied hard against that during
// the WGLC of the LAMPS IETF working group. Also, ACVP tests provide and expect
// semi-expanded keys, so we implement them here for testing purposes.
func semiExpandedPrivKeySize(p parameters) int {
k, l := p.k, p.l
ηBitlen := bits.Len(uint(p.η)) + 1
// ρ + K + tr + l × n × η-bit coefficients of s₁ +
// k × n × η-bit coefficients of s₂ + k × n × 13-bit coefficients of t₀
return 32 + 32 + 64 + l*n*ηBitlen/8 + k*n*ηBitlen/8 + k*n*13/8
}
// TestingOnlyNewPrivateKeyFromSemiExpanded creates a PrivateKey from a
// semi-expanded private key encoding, for testing purposes. It rejects
// inconsistent keys.
//
// [PrivateKey.Bytes] must NOT be called on the resulting key, as it will
// produce a random value.
func TestingOnlyNewPrivateKeyFromSemiExpanded(sk []byte) (*PrivateKey, error) {
var p parameters
switch len(sk) {
case semiExpandedPrivKeySize(params44):
p = params44
case semiExpandedPrivKeySize(params65):
p = params65
case semiExpandedPrivKeySize(params87):
p = params87
default:
return nil, errors.New("mldsa: invalid semi-expanded private key size")
}
k, l := p.k, p.l
ρ, K, tr, s1, s2, t0, err := skDecode(sk, p)
if err != nil {
return nil, err
}
priv := &PrivateKey{pub: PublicKey{p: p}}
priv.k = K
priv.pub.tr = tr
A := priv.pub.a[:k*l]
computeMatrixA(A, ρ[:], p)
for r := range l {
priv.s1[r] = ntt(s1[r])
}
for r := range k {
priv.s2[r] = ntt(s2[r])
}
for r := range k {
priv.t0[r] = ntt(t0[r])
}
// We need to put something in priv.seed, and putting random bytes feels
// safer than putting anything predictable.
drbg.Read(priv.seed[:])
// Making this format *even more* annoying, we need to recompute t1 from ρ,
// s1, and s2 if we want to generate the public key. This is essentially as
// much work as regenerating everything from seed.
//
// You might also notice that the semi-expanded format also stores t0 and a
// hash of the public key, though. How are we supposed to check they are
// consistent without regenerating the public key? Do we even need to check?
// Who knows! FIPS 204 says
//
// > Note that there exist malformed inputs that can cause skDecode to
// > return values that are not in the correct range. Hence, skDecode
// > should only be run on inputs that come from trusted sources.
//
// so it sounds like it doesn't even want us to check the coefficients are
// within bounds, but especially if using this format for key exchange, that
// sounds like a bad idea. So we check everything.
t1 := make([][n]uint16, k, maxK)
for i := range k {
tHat := priv.s2[i]
for j := range l {
tHat = polyAdd(tHat, nttMul(A[i*l+j], priv.s1[j]))
}
t := inverseNTT(tHat)
for j := range n {
r1, r0 := power2Round(t[j])
t1[i][j] = r1
if r0 != t0[i][j] {
return nil, errors.New("mldsa: semi-expanded private key inconsistent with t0")
}
}
}
pk := pkEncode(priv.pub.raw[:0], ρ[:], t1, p)
if computePublicKeyHash(pk) != tr {
return nil, errors.New("mldsa: semi-expanded private key inconsistent with public key hash")
}
computeT1Hat(priv.pub.t1[:k], t1) // NTT(t₁ ⋅ 2ᵈ)
return priv, nil
}
func TestingOnlyPrivateKeySemiExpandedBytes(priv *PrivateKey) []byte {
k, l, η := priv.pub.p.k, priv.pub.p.l, priv.pub.p.η
sk := make([]byte, 0, semiExpandedPrivKeySize(priv.pub.p))
sk = append(sk, priv.pub.raw[:32]...) // ρ
sk = append(sk, priv.k[:]...) // K
sk = append(sk, priv.pub.tr[:]...) // tr
for i := range l {
sk = bitPackSlow(sk, inverseNTT(priv.s1[i]), η, η)
}
for i := range k {
sk = bitPackSlow(sk, inverseNTT(priv.s2[i]), η, η)
}
const bound = 1 << (13 - 1) // 2^(d-1)
for i := range k {
sk = bitPackSlow(sk, inverseNTT(priv.t0[i]), bound-1, bound)
}
return sk
}
func skDecode(sk []byte, p parameters) (ρ, K [32]byte, tr [64]byte, s1, s2, t0 []ringElement, err error) {
k, l, η := p.k, p.l, p.η
if len(sk) != semiExpandedPrivKeySize(p) {
err = errors.New("mldsa: invalid semi-expanded private key size")
return
}
copy(ρ[:], sk[:32])
sk = sk[32:]
copy(K[:], sk[:32])
sk = sk[32:]
copy(tr[:], sk[:64])
sk = sk[64:]
s1 = make([]ringElement, l)
for i := range l {
length := n * bits.Len(uint(η)*2) / 8
s1[i], err = bitUnpackSlow(sk[:length], η, η)
if err != nil {
return
}
sk = sk[length:]
}
s2 = make([]ringElement, k)
for i := range k {
length := n * bits.Len(uint(η)*2) / 8
s2[i], err = bitUnpackSlow(sk[:length], η, η)
if err != nil {
return
}
sk = sk[length:]
}
const bound = 1 << (13 - 1) // 2^(d-1)
t0 = make([]ringElement, k)
for i := range k {
length := n * 13 / 8
t0[i], err = bitUnpackSlow(sk[:length], bound-1, bound)
if err != nil {
return
}
sk = sk[length:]
}
return
}
func bitPackSlow(buf []byte, r ringElement, a, b int) []byte {
bitlen := bits.Len(uint(a + b))
if bitlen <= 0 || bitlen > 16 {
panic("mldsa: internal error: invalid bitlen")
}
out, v := sliceForAppend(buf, n*bitlen/8)
var acc uint32
var accBits uint
for i := range r {
w := int32(b) - fieldCenteredMod(r[i])
acc |= uint32(w) << accBits
accBits += uint(bitlen)
for accBits >= 8 {
v[0] = byte(acc)
v = v[1:]
acc >>= 8
accBits -= 8
}
}
if accBits > 0 {
v[0] = byte(acc)
}
return out
}
func bitUnpackSlow(v []byte, a, b int) (ringElement, error) {
bitlen := bits.Len(uint(a + b))
if bitlen <= 0 || bitlen > 16 {
panic("mldsa: internal error: invalid bitlen")
}
if len(v) != n*bitlen/8 {
return ringElement{}, errors.New("mldsa: invalid input length for bitUnpackSlow")
}
mask := uint32((1 << bitlen) - 1)
maxValue := uint32(a + b)
var r ringElement
var acc uint32
var accBits uint
vIdx := 0
for i := range r {
for accBits < uint(bitlen) {
if vIdx < len(v) {
acc |= uint32(v[vIdx]) << accBits
vIdx++
accBits += 8
}
}
w := acc & mask
if w > maxValue {
return ringElement{}, errors.New("mldsa: coefficient out of range")
}
r[i] = fieldSubToMontgomery(uint32(b), w)
acc >>= bitlen
accBits -= uint(bitlen)
}
return r, nil
}