x/crypto/ssh: add padding oracle countermeasures for AES-CBC.

This deprives an attacker of feedback for guesses against the packet
length given by the connection dropping.

Change-Id: I14939a82e5243a86d192bb18be93d45589227147
Reviewed-on: https://go-review.googlesource.com/9908
Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/ssh/cipher.go b/ssh/cipher.go
index b173183..3e06da0 100644
--- a/ssh/cipher.go
+++ b/ssh/cipher.go
@@ -14,6 +14,7 @@
 	"fmt"
 	"hash"
 	"io"
+	"io/ioutil"
 )
 
 const (
@@ -350,6 +351,7 @@
 // cbcCipher implements aes128-cbc cipher defined in RFC 4253 section 6.1
 type cbcCipher struct {
 	mac       hash.Hash
+	macSize   uint32
 	decrypter cipher.BlockMode
 	encrypter cipher.BlockMode
 
@@ -357,6 +359,10 @@
 	seqNumBytes [4]byte
 	packetData  []byte
 	macResult   []byte
+
+	// Amount of data we should still read to hide which
+	// verification error triggered.
+	oracleCamouflage uint32
 }
 
 func newAESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
@@ -364,12 +370,18 @@
 	if err != nil {
 		return nil, err
 	}
-	return &cbcCipher{
+
+	cbc := &cbcCipher{
 		mac:        macModes[algs.MAC].new(macKey),
 		decrypter:  cipher.NewCBCDecrypter(c, iv),
 		encrypter:  cipher.NewCBCEncrypter(c, iv),
 		packetData: make([]byte, 1024),
-	}, nil
+	}
+	if cbc.mac != nil {
+		cbc.macSize = uint32(cbc.mac.Size())
+	}
+
+	return cbc, nil
 }
 
 func maxUInt32(a, b int) uint32 {
@@ -385,42 +397,58 @@
 	cbcMinPaddingSize        = 4
 )
 
+// cbcError represents a verification error that may leak information.
+type cbcError string
+
+func (e cbcError) Error() string { return string(e) }
+
 func (c *cbcCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
+	p, err := c.readPacketLeaky(seqNum, r)
+	if err != nil {
+		if _, ok := err.(cbcError); ok {
+			// Verification error: read a fixed amount of
+			// data, to make distinguishing between
+			// failing MAC and failing length check more
+			// difficult.
+			io.CopyN(ioutil.Discard, r, int64(c.oracleCamouflage))
+		}
+	}
+	return p, err
+}
+
+func (c *cbcCipher) readPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error) {
 	blockSize := c.decrypter.BlockSize()
 
 	// Read the header, which will include some of the subsequent data in the
 	// case of block ciphers - this is copied back to the payload later.
 	// How many bytes of payload/padding will be read with this first read.
-	firstBlockLength := (prefixLen + blockSize - 1) / blockSize * blockSize
+	firstBlockLength := uint32((prefixLen + blockSize - 1) / blockSize * blockSize)
 	firstBlock := c.packetData[:firstBlockLength]
 	if _, err := io.ReadFull(r, firstBlock); err != nil {
 		return nil, err
 	}
 
+	c.oracleCamouflage = maxPacket + 4 + c.macSize - firstBlockLength
+
 	c.decrypter.CryptBlocks(firstBlock, firstBlock)
 	length := binary.BigEndian.Uint32(firstBlock[:4])
 	if length > maxPacket {
-		return nil, errors.New("ssh: packet too large")
+		return nil, cbcError("ssh: packet too large")
 	}
 	if length+4 < maxUInt32(cbcMinPacketSize, blockSize) {
 		// The minimum size of a packet is 16 (or the cipher block size, whichever
 		// is larger) bytes.
-		return nil, errors.New("ssh: packet too small")
+		return nil, cbcError("ssh: packet too small")
 	}
 	// The length of the packet (including the length field but not the MAC) must
 	// be a multiple of the block size or 8, whichever is larger.
 	if (length+4)%maxUInt32(cbcMinPacketSizeMultiple, blockSize) != 0 {
-		return nil, errors.New("ssh: invalid packet length multiple")
+		return nil, cbcError("ssh: invalid packet length multiple")
 	}
 
 	paddingLength := uint32(firstBlock[4])
 	if paddingLength < cbcMinPaddingSize || length <= paddingLength+1 {
-		return nil, errors.New("ssh: invalid packet length")
-	}
-
-	var macSize uint32
-	if c.mac != nil {
-		macSize = uint32(c.mac.Size())
+		return nil, cbcError("ssh: invalid packet length")
 	}
 
 	// Positions within the c.packetData buffer:
@@ -428,7 +456,7 @@
 	paddingStart := macStart - paddingLength
 
 	// Entire packet size, starting before length, ending at end of mac.
-	entirePacketSize := macStart + macSize
+	entirePacketSize := macStart + c.macSize
 
 	// Ensure c.packetData is large enough for the entire packet data.
 	if uint32(cap(c.packetData)) < entirePacketSize {
@@ -440,8 +468,10 @@
 		c.packetData = c.packetData[:entirePacketSize]
 	}
 
-	if _, err := io.ReadFull(r, c.packetData[firstBlockLength:]); err != nil {
+	if n, err := io.ReadFull(r, c.packetData[firstBlockLength:]); err != nil {
 		return nil, err
+	} else {
+		c.oracleCamouflage -= uint32(n)
 	}
 
 	remainingCrypted := c.packetData[firstBlockLength:macStart]
@@ -455,7 +485,7 @@
 		c.mac.Write(c.packetData[:macStart])
 		c.macResult = c.mac.Sum(c.macResult[:0])
 		if subtle.ConstantTimeCompare(c.macResult, mac) != 1 {
-			return nil, errors.New("ssh: MAC failure")
+			return nil, cbcError("ssh: MAC failure")
 		}
 	}
 
@@ -474,13 +504,9 @@
 	length := encLength - 4
 	paddingLength := int(length) - (1 + len(packet))
 
-	var macSize uint32
-	if c.mac != nil {
-		macSize = uint32(c.mac.Size())
-	}
 	// Overall buffer contains: header, payload, padding, mac.
 	// Space for the MAC is reserved in the capacity but not the slice length.
-	bufferSize := encLength + macSize
+	bufferSize := encLength + c.macSize
 	if uint32(cap(c.packetData)) < bufferSize {
 		c.packetData = make([]byte, encLength, bufferSize)
 	} else {
diff --git a/ssh/cipher_test.go b/ssh/cipher_test.go
index 2fb75d0..d67a93a 100644
--- a/ssh/cipher_test.go
+++ b/ssh/cipher_test.go
@@ -62,3 +62,66 @@
 		}
 	}
 }
+
+func TestCBCOracleCounterMeasure(t *testing.T) {
+	cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil}
+	defer delete(cipherModes, aes128cbcID)
+
+	kr := &kexResult{Hash: crypto.SHA1}
+	algs := directionAlgorithms{
+		Cipher:      aes128cbcID,
+		MAC:         "hmac-sha1",
+		Compression: "none",
+	}
+	client, err := newPacketCipher(clientKeys, algs, kr)
+	if err != nil {
+		t.Fatalf("newPacketCipher(client): %v", err)
+	}
+
+	want := "bla bla"
+	input := []byte(want)
+	buf := &bytes.Buffer{}
+	if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
+		t.Errorf("writePacket: %v", err)
+	}
+
+	packetSize := buf.Len()
+	buf.Write(make([]byte, 2*maxPacket))
+
+	// We corrupt each byte, but this usually will only test the
+	// 'packet too large' or 'MAC failure' cases.
+	lastRead := -1
+	for i := 0; i < packetSize; i++ {
+		server, err := newPacketCipher(clientKeys, algs, kr)
+		if err != nil {
+			t.Fatalf("newPacketCipher(client): %v", err)
+		}
+
+		fresh := &bytes.Buffer{}
+		fresh.Write(buf.Bytes())
+		fresh.Bytes()[i] ^= 0x01
+
+		before := fresh.Len()
+		_, err = server.readPacket(0, fresh)
+		if err == nil {
+			t.Errorf("corrupt byte %d: readPacket succeeded ", i)
+			continue
+		}
+		if _, ok := err.(cbcError); !ok {
+			t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err)
+			continue
+		}
+
+		after := fresh.Len()
+		bytesRead := before - after
+		if bytesRead < maxPacket {
+			t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket)
+			continue
+		}
+
+		if i > 0 && bytesRead != lastRead {
+			t.Errorf("corrupt byte %d: want %d, got %d bytes read", bytesRead, lastRead)
+		}
+		lastRead = bytesRead
+	}
+}