math/big: add (*Int).FillBytes

Replaced almost every use of Bytes with FillBytes.

Note that the approved proposal was for

    func (*Int) FillBytes(buf []byte)

while this implements

    func (*Int) FillBytes(buf []byte) []byte

because the latter was far nicer to use in all callsites.

Fixes #35833

Change-Id: Ia912df123e5d79b763845312ea3d9a8051343c0a
Reviewed-on: https://go-review.googlesource.com/c/go/+/230397
Reviewed-by: Robert Griesemer <gri@golang.org>
diff --git a/src/crypto/elliptic/elliptic.go b/src/crypto/elliptic/elliptic.go
index e2f71cd..bd5168c 100644
--- a/src/crypto/elliptic/elliptic.go
+++ b/src/crypto/elliptic/elliptic.go
@@ -277,7 +277,7 @@
 func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err error) {
 	N := curve.Params().N
 	bitSize := N.BitLen()
-	byteLen := (bitSize + 7) >> 3
+	byteLen := (bitSize + 7) / 8
 	priv = make([]byte, byteLen)
 
 	for x == nil {
@@ -304,15 +304,14 @@
 
 // Marshal converts a point into the uncompressed form specified in section 4.3.6 of ANSI X9.62.
 func Marshal(curve Curve, x, y *big.Int) []byte {
-	byteLen := (curve.Params().BitSize + 7) >> 3
+	byteLen := (curve.Params().BitSize + 7) / 8
 
 	ret := make([]byte, 1+2*byteLen)
 	ret[0] = 4 // uncompressed point
 
-	xBytes := x.Bytes()
-	copy(ret[1+byteLen-len(xBytes):], xBytes)
-	yBytes := y.Bytes()
-	copy(ret[1+2*byteLen-len(yBytes):], yBytes)
+	x.FillBytes(ret[1 : 1+byteLen])
+	y.FillBytes(ret[1+byteLen : 1+2*byteLen])
+
 	return ret
 }
 
@@ -320,7 +319,7 @@
 // It is an error if the point is not in uncompressed form or is not on the curve.
 // On error, x = nil.
 func Unmarshal(curve Curve, data []byte) (x, y *big.Int) {
-	byteLen := (curve.Params().BitSize + 7) >> 3
+	byteLen := (curve.Params().BitSize + 7) / 8
 	if len(data) != 1+2*byteLen {
 		return
 	}
diff --git a/src/crypto/rsa/pkcs1v15.go b/src/crypto/rsa/pkcs1v15.go
index 499242f..3208119 100644
--- a/src/crypto/rsa/pkcs1v15.go
+++ b/src/crypto/rsa/pkcs1v15.go
@@ -61,8 +61,7 @@
 	m := new(big.Int).SetBytes(em)
 	c := encrypt(new(big.Int), pub, m)
 
-	copyWithLeftPad(em, c.Bytes())
-	return em, nil
+	return c.FillBytes(em), nil
 }
 
 // DecryptPKCS1v15 decrypts a plaintext using RSA and the padding scheme from PKCS#1 v1.5.
@@ -150,7 +149,7 @@
 		return
 	}
 
-	em = leftPad(m.Bytes(), k)
+	em = m.FillBytes(make([]byte, k))
 	firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
 	secondByteIsTwo := subtle.ConstantTimeByteEq(em[1], 2)
 
@@ -256,8 +255,7 @@
 		return nil, err
 	}
 
-	copyWithLeftPad(em, c.Bytes())
-	return em, nil
+	return c.FillBytes(em), nil
 }
 
 // VerifyPKCS1v15 verifies an RSA PKCS#1 v1.5 signature.
@@ -286,7 +284,7 @@
 
 	c := new(big.Int).SetBytes(sig)
 	m := encrypt(new(big.Int), pub, c)
-	em := leftPad(m.Bytes(), k)
+	em := m.FillBytes(make([]byte, k))
 	// EM = 0x00 || 0x01 || PS || 0x00 || T
 
 	ok := subtle.ConstantTimeByteEq(em[0], 0)
@@ -323,13 +321,3 @@
 	}
 	return
 }
-
-// copyWithLeftPad copies src to the end of dest, padding with zero bytes as
-// needed.
-func copyWithLeftPad(dest, src []byte) {
-	numPaddingBytes := len(dest) - len(src)
-	for i := 0; i < numPaddingBytes; i++ {
-		dest[i] = 0
-	}
-	copy(dest[numPaddingBytes:], src)
-}
diff --git a/src/crypto/rsa/pss.go b/src/crypto/rsa/pss.go
index f9844d8..b2adbed 100644
--- a/src/crypto/rsa/pss.go
+++ b/src/crypto/rsa/pss.go
@@ -207,20 +207,19 @@
 // Note that hashed must be the result of hashing the input message using the
 // given hash function. salt is a random sequence of bytes whose length will be
 // later used to verify the signature.
-func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) (s []byte, err error) {
+func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) {
 	emBits := priv.N.BitLen() - 1
 	em, err := emsaPSSEncode(hashed, emBits, salt, hash.New())
 	if err != nil {
-		return
+		return nil, err
 	}
 	m := new(big.Int).SetBytes(em)
 	c, err := decryptAndCheck(rand, priv, m)
 	if err != nil {
-		return
+		return nil, err
 	}
-	s = make([]byte, priv.Size())
-	copyWithLeftPad(s, c.Bytes())
-	return
+	s := make([]byte, priv.Size())
+	return c.FillBytes(s), nil
 }
 
 const (
@@ -296,11 +295,9 @@
 	m := encrypt(new(big.Int), pub, s)
 	emBits := pub.N.BitLen() - 1
 	emLen := (emBits + 7) / 8
-	emBytes := m.Bytes()
-	if emLen < len(emBytes) {
+	if m.BitLen() > emLen*8 {
 		return ErrVerification
 	}
-	em := make([]byte, emLen)
-	copyWithLeftPad(em, emBytes)
+	em := m.FillBytes(make([]byte, emLen))
 	return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New())
 }
diff --git a/src/crypto/rsa/rsa.go b/src/crypto/rsa/rsa.go
index b4bfa13..28eb592 100644
--- a/src/crypto/rsa/rsa.go
+++ b/src/crypto/rsa/rsa.go
@@ -416,16 +416,9 @@
 	m := new(big.Int)
 	m.SetBytes(em)
 	c := encrypt(new(big.Int), pub, m)
-	out := c.Bytes()
 
-	if len(out) < k {
-		// If the output is too small, we need to left-pad with zeros.
-		t := make([]byte, k)
-		copy(t[k-len(out):], out)
-		out = t
-	}
-
-	return out, nil
+	out := make([]byte, k)
+	return c.FillBytes(out), nil
 }
 
 // ErrDecryption represents a failure to decrypt a message.
@@ -597,12 +590,9 @@
 	lHash := hash.Sum(nil)
 	hash.Reset()
 
-	// Converting the plaintext number to bytes will strip any
-	// leading zeros so we may have to left pad. We do this unconditionally
-	// to avoid leaking timing information. (Although we still probably
-	// leak the number of leading zeros. It's not clear that we can do
-	// anything about this.)
-	em := leftPad(m.Bytes(), k)
+	// We probably leak the number of leading zeros.
+	// It's not clear that we can do anything about this.
+	em := m.FillBytes(make([]byte, k))
 
 	firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
 
@@ -643,15 +633,3 @@
 
 	return rest[index+1:], nil
 }
-
-// leftPad returns a new slice of length size. The contents of input are right
-// aligned in the new slice.
-func leftPad(input []byte, size int) (out []byte) {
-	n := len(input)
-	if n > size {
-		n = size
-	}
-	out = make([]byte, size)
-	copy(out[len(out)-n:], input)
-	return
-}
diff --git a/src/crypto/tls/key_schedule.go b/src/crypto/tls/key_schedule.go
index 2aab323..3140169 100644
--- a/src/crypto/tls/key_schedule.go
+++ b/src/crypto/tls/key_schedule.go
@@ -173,11 +173,8 @@
 	}
 
 	xShared, _ := curve.ScalarMult(x, y, p.privateKey)
-	sharedKey := make([]byte, (curve.Params().BitSize+7)>>3)
-	xBytes := xShared.Bytes()
-	copy(sharedKey[len(sharedKey)-len(xBytes):], xBytes)
-
-	return sharedKey
+	sharedKey := make([]byte, (curve.Params().BitSize+7)/8)
+	return xShared.FillBytes(sharedKey)
 }
 
 type x25519Parameters struct {
diff --git a/src/crypto/x509/sec1.go b/src/crypto/x509/sec1.go
index 0bfb90c..52c108f 100644
--- a/src/crypto/x509/sec1.go
+++ b/src/crypto/x509/sec1.go
@@ -52,13 +52,10 @@
 // marshalECPrivateKey marshals an EC private key into ASN.1, DER format and
 // sets the curve ID to the given OID, or omits it if OID is nil.
 func marshalECPrivateKeyWithOID(key *ecdsa.PrivateKey, oid asn1.ObjectIdentifier) ([]byte, error) {
-	privateKeyBytes := key.D.Bytes()
-	paddedPrivateKey := make([]byte, (key.Curve.Params().N.BitLen()+7)/8)
-	copy(paddedPrivateKey[len(paddedPrivateKey)-len(privateKeyBytes):], privateKeyBytes)
-
+	privateKey := make([]byte, (key.Curve.Params().N.BitLen()+7)/8)
 	return asn1.Marshal(ecPrivateKey{
 		Version:       1,
-		PrivateKey:    paddedPrivateKey,
+		PrivateKey:    key.D.FillBytes(privateKey),
 		NamedCurveOID: oid,
 		PublicKey:     asn1.BitString{Bytes: elliptic.Marshal(key.Curve, key.X, key.Y)},
 	})
diff --git a/src/math/big/int.go b/src/math/big/int.go
index 8816cf5..65f3248 100644
--- a/src/math/big/int.go
+++ b/src/math/big/int.go
@@ -447,11 +447,26 @@
 }
 
 // Bytes returns the absolute value of x as a big-endian byte slice.
+//
+// To use a fixed length slice, or a preallocated one, use FillBytes.
 func (x *Int) Bytes() []byte {
 	buf := make([]byte, len(x.abs)*_S)
 	return buf[x.abs.bytes(buf):]
 }
 
+// FillBytes sets buf to the absolute value of x, storing it as a zero-extended
+// big-endian byte slice, and returns buf.
+//
+// If the absolute value of x doesn't fit in buf, FillBytes will panic.
+func (x *Int) FillBytes(buf []byte) []byte {
+	// Clear whole buffer. (This gets optimized into a memclr.)
+	for i := range buf {
+		buf[i] = 0
+	}
+	x.abs.bytes(buf)
+	return buf
+}
+
 // BitLen returns the length of the absolute value of x in bits.
 // The bit length of 0 is 0.
 func (x *Int) BitLen() int {
diff --git a/src/math/big/int_test.go b/src/math/big/int_test.go
index e3a1587..3c85573 100644
--- a/src/math/big/int_test.go
+++ b/src/math/big/int_test.go
@@ -1840,3 +1840,57 @@
 		})
 	}
 }
+
+func TestFillBytes(t *testing.T) {
+	checkResult := func(t *testing.T, buf []byte, want *Int) {
+		t.Helper()
+		got := new(Int).SetBytes(buf)
+		if got.CmpAbs(want) != 0 {
+			t.Errorf("got 0x%x, want 0x%x: %x", got, want, buf)
+		}
+	}
+	panics := func(f func()) (panic bool) {
+		defer func() { panic = recover() != nil }()
+		f()
+		return
+	}
+
+	for _, n := range []string{
+		"0",
+		"1000",
+		"0xffffffff",
+		"-0xffffffff",
+		"0xffffffffffffffff",
+		"0x10000000000000000",
+		"0xabababababababababababababababababababababababababa",
+		"0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
+	} {
+		t.Run(n, func(t *testing.T) {
+			t.Logf(n)
+			x, ok := new(Int).SetString(n, 0)
+			if !ok {
+				panic("invalid test entry")
+			}
+
+			// Perfectly sized buffer.
+			byteLen := (x.BitLen() + 7) / 8
+			buf := make([]byte, byteLen)
+			checkResult(t, x.FillBytes(buf), x)
+
+			// Way larger, checking all bytes get zeroed.
+			buf = make([]byte, 100)
+			for i := range buf {
+				buf[i] = 0xff
+			}
+			checkResult(t, x.FillBytes(buf), x)
+
+			// Too small.
+			if byteLen > 0 {
+				buf = make([]byte, byteLen-1)
+				if !panics(func() { x.FillBytes(buf) }) {
+					t.Errorf("expected panic for small buffer and value %x", x)
+				}
+			}
+		})
+	}
+}
diff --git a/src/math/big/nat.go b/src/math/big/nat.go
index c31ec51..6a3989b 100644
--- a/src/math/big/nat.go
+++ b/src/math/big/nat.go
@@ -1476,19 +1476,26 @@
 }
 
 // bytes writes the value of z into buf using big-endian encoding.
-// len(buf) must be >= len(z)*_S. The value of z is encoded in the
-// slice buf[i:]. The number i of unused bytes at the beginning of
-// buf is returned as result.
+// The value of z is encoded in the slice buf[i:]. If the value of z
+// cannot be represented in buf, bytes panics. The number i of unused
+// bytes at the beginning of buf is returned as result.
 func (z nat) bytes(buf []byte) (i int) {
 	i = len(buf)
 	for _, d := range z {
 		for j := 0; j < _S; j++ {
 			i--
-			buf[i] = byte(d)
+			if i >= 0 {
+				buf[i] = byte(d)
+			} else if byte(d) != 0 {
+				panic("math/big: buffer too small to fit value")
+			}
 			d >>= 8
 		}
 	}
 
+	if i < 0 {
+		i = 0
+	}
 	for i < len(buf) && buf[i] == 0 {
 		i++
 	}