x/crypto/ssh: add padding oracle countermeasures for AES-CBC.
This deprives an attacker of feedback for guesses against the packet
length given by the connection dropping.
Change-Id: I14939a82e5243a86d192bb18be93d45589227147
Reviewed-on: https://go-review.googlesource.com/9908
Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/ssh/cipher.go b/ssh/cipher.go
index b173183..3e06da0 100644
--- a/ssh/cipher.go
+++ b/ssh/cipher.go
@@ -14,6 +14,7 @@
"fmt"
"hash"
"io"
+ "io/ioutil"
)
const (
@@ -350,6 +351,7 @@
// cbcCipher implements aes128-cbc cipher defined in RFC 4253 section 6.1
type cbcCipher struct {
mac hash.Hash
+ macSize uint32
decrypter cipher.BlockMode
encrypter cipher.BlockMode
@@ -357,6 +359,10 @@
seqNumBytes [4]byte
packetData []byte
macResult []byte
+
+ // Amount of data we should still read to hide which
+ // verification error triggered.
+ oracleCamouflage uint32
}
func newAESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
@@ -364,12 +370,18 @@
if err != nil {
return nil, err
}
- return &cbcCipher{
+
+ cbc := &cbcCipher{
mac: macModes[algs.MAC].new(macKey),
decrypter: cipher.NewCBCDecrypter(c, iv),
encrypter: cipher.NewCBCEncrypter(c, iv),
packetData: make([]byte, 1024),
- }, nil
+ }
+ if cbc.mac != nil {
+ cbc.macSize = uint32(cbc.mac.Size())
+ }
+
+ return cbc, nil
}
func maxUInt32(a, b int) uint32 {
@@ -385,42 +397,58 @@
cbcMinPaddingSize = 4
)
+// cbcError represents a verification error that may leak information.
+type cbcError string
+
+func (e cbcError) Error() string { return string(e) }
+
func (c *cbcCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
+ p, err := c.readPacketLeaky(seqNum, r)
+ if err != nil {
+ if _, ok := err.(cbcError); ok {
+ // Verification error: read a fixed amount of
+ // data, to make distinguishing between
+ // failing MAC and failing length check more
+ // difficult.
+ io.CopyN(ioutil.Discard, r, int64(c.oracleCamouflage))
+ }
+ }
+ return p, err
+}
+
+func (c *cbcCipher) readPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error) {
blockSize := c.decrypter.BlockSize()
// Read the header, which will include some of the subsequent data in the
// case of block ciphers - this is copied back to the payload later.
// How many bytes of payload/padding will be read with this first read.
- firstBlockLength := (prefixLen + blockSize - 1) / blockSize * blockSize
+ firstBlockLength := uint32((prefixLen + blockSize - 1) / blockSize * blockSize)
firstBlock := c.packetData[:firstBlockLength]
if _, err := io.ReadFull(r, firstBlock); err != nil {
return nil, err
}
+ c.oracleCamouflage = maxPacket + 4 + c.macSize - firstBlockLength
+
c.decrypter.CryptBlocks(firstBlock, firstBlock)
length := binary.BigEndian.Uint32(firstBlock[:4])
if length > maxPacket {
- return nil, errors.New("ssh: packet too large")
+ return nil, cbcError("ssh: packet too large")
}
if length+4 < maxUInt32(cbcMinPacketSize, blockSize) {
// The minimum size of a packet is 16 (or the cipher block size, whichever
// is larger) bytes.
- return nil, errors.New("ssh: packet too small")
+ return nil, cbcError("ssh: packet too small")
}
// The length of the packet (including the length field but not the MAC) must
// be a multiple of the block size or 8, whichever is larger.
if (length+4)%maxUInt32(cbcMinPacketSizeMultiple, blockSize) != 0 {
- return nil, errors.New("ssh: invalid packet length multiple")
+ return nil, cbcError("ssh: invalid packet length multiple")
}
paddingLength := uint32(firstBlock[4])
if paddingLength < cbcMinPaddingSize || length <= paddingLength+1 {
- return nil, errors.New("ssh: invalid packet length")
- }
-
- var macSize uint32
- if c.mac != nil {
- macSize = uint32(c.mac.Size())
+ return nil, cbcError("ssh: invalid packet length")
}
// Positions within the c.packetData buffer:
@@ -428,7 +456,7 @@
paddingStart := macStart - paddingLength
// Entire packet size, starting before length, ending at end of mac.
- entirePacketSize := macStart + macSize
+ entirePacketSize := macStart + c.macSize
// Ensure c.packetData is large enough for the entire packet data.
if uint32(cap(c.packetData)) < entirePacketSize {
@@ -440,8 +468,10 @@
c.packetData = c.packetData[:entirePacketSize]
}
- if _, err := io.ReadFull(r, c.packetData[firstBlockLength:]); err != nil {
+ if n, err := io.ReadFull(r, c.packetData[firstBlockLength:]); err != nil {
return nil, err
+ } else {
+ c.oracleCamouflage -= uint32(n)
}
remainingCrypted := c.packetData[firstBlockLength:macStart]
@@ -455,7 +485,7 @@
c.mac.Write(c.packetData[:macStart])
c.macResult = c.mac.Sum(c.macResult[:0])
if subtle.ConstantTimeCompare(c.macResult, mac) != 1 {
- return nil, errors.New("ssh: MAC failure")
+ return nil, cbcError("ssh: MAC failure")
}
}
@@ -474,13 +504,9 @@
length := encLength - 4
paddingLength := int(length) - (1 + len(packet))
- var macSize uint32
- if c.mac != nil {
- macSize = uint32(c.mac.Size())
- }
// Overall buffer contains: header, payload, padding, mac.
// Space for the MAC is reserved in the capacity but not the slice length.
- bufferSize := encLength + macSize
+ bufferSize := encLength + c.macSize
if uint32(cap(c.packetData)) < bufferSize {
c.packetData = make([]byte, encLength, bufferSize)
} else {
diff --git a/ssh/cipher_test.go b/ssh/cipher_test.go
index 2fb75d0..d67a93a 100644
--- a/ssh/cipher_test.go
+++ b/ssh/cipher_test.go
@@ -62,3 +62,66 @@
}
}
}
+
+func TestCBCOracleCounterMeasure(t *testing.T) {
+ cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil}
+ defer delete(cipherModes, aes128cbcID)
+
+ kr := &kexResult{Hash: crypto.SHA1}
+ algs := directionAlgorithms{
+ Cipher: aes128cbcID,
+ MAC: "hmac-sha1",
+ Compression: "none",
+ }
+ client, err := newPacketCipher(clientKeys, algs, kr)
+ if err != nil {
+ t.Fatalf("newPacketCipher(client): %v", err)
+ }
+
+ want := "bla bla"
+ input := []byte(want)
+ buf := &bytes.Buffer{}
+ if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
+ t.Errorf("writePacket: %v", err)
+ }
+
+ packetSize := buf.Len()
+ buf.Write(make([]byte, 2*maxPacket))
+
+ // We corrupt each byte, but this usually will only test the
+ // 'packet too large' or 'MAC failure' cases.
+ lastRead := -1
+ for i := 0; i < packetSize; i++ {
+ server, err := newPacketCipher(clientKeys, algs, kr)
+ if err != nil {
+ t.Fatalf("newPacketCipher(client): %v", err)
+ }
+
+ fresh := &bytes.Buffer{}
+ fresh.Write(buf.Bytes())
+ fresh.Bytes()[i] ^= 0x01
+
+ before := fresh.Len()
+ _, err = server.readPacket(0, fresh)
+ if err == nil {
+ t.Errorf("corrupt byte %d: readPacket succeeded ", i)
+ continue
+ }
+ if _, ok := err.(cbcError); !ok {
+ t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err)
+ continue
+ }
+
+ after := fresh.Len()
+ bytesRead := before - after
+ if bytesRead < maxPacket {
+ t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket)
+ continue
+ }
+
+ if i > 0 && bytesRead != lastRead {
+ t.Errorf("corrupt byte %d: want %d, got %d bytes read", bytesRead, lastRead)
+ }
+ lastRead = bytesRead
+ }
+}