x/crypto/ssh: add padding oracle countermeasures for AES-CBC.

This denies an attacker the feedback that the point at which the
connection is dropped would otherwise give about guesses of the packet
length.

Change-Id: I14939a82e5243a86d192bb18be93d45589227147
Reviewed-on: https://go-review.googlesource.com/9908
Reviewed-by: Adam Langley <agl@golang.org>
diff --git a/ssh/cipher_test.go b/ssh/cipher_test.go
index 2fb75d0..d67a93a 100644
--- a/ssh/cipher_test.go
+++ b/ssh/cipher_test.go
@@ -62,3 +62,78 @@
 		}
 	}
 }
+
+// TestCBCOracleCounterMeasure checks the padding oracle countermeasure for
+// AES-128-CBC: corrupting any single byte of a packet must make readPacket
+// fail with a cbcError after consuming at least maxPacket bytes, and the
+// same number of bytes for every corrupted position, so that the amount of
+// data read before the failure does not depend on which byte was corrupted.
+func TestCBCOracleCounterMeasure(t *testing.T) {
+	// Register the CBC mode only for the duration of this test.
+	cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil}
+	defer delete(cipherModes, aes128cbcID)
+
+	kr := &kexResult{Hash: crypto.SHA1}
+	algs := directionAlgorithms{
+		Cipher:      aes128cbcID,
+		MAC:         "hmac-sha1",
+		Compression: "none",
+	}
+	client, err := newPacketCipher(clientKeys, algs, kr)
+	if err != nil {
+		t.Fatalf("newPacketCipher(client): %v", err)
+	}
+
+	input := []byte("bla bla")
+	buf := &bytes.Buffer{}
+	if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
+		t.Fatalf("writePacket: %v", err)
+	}
+
+	packetSize := buf.Len()
+	// Append trailing zeros so readPacket has enough data to keep consuming
+	// after it detects corruption; the checks below expect at least
+	// maxPacket bytes to be read.
+	buf.Write(make([]byte, 2*maxPacket))
+
+	// We corrupt each byte, but this usually will only test the
+	// 'packet too large' or 'MAC failure' cases.
+	lastRead := -1
+	for i := 0; i < packetSize; i++ {
+		// Decrypt with a freshly keyed cipher so every attempt starts from
+		// the same initial state, independent of earlier iterations.
+		server, err := newPacketCipher(clientKeys, algs, kr)
+		if err != nil {
+			t.Fatalf("newPacketCipher(server): %v", err)
+		}
+
+		fresh := &bytes.Buffer{}
+		fresh.Write(buf.Bytes())
+		fresh.Bytes()[i] ^= 0x01
+
+		before := fresh.Len()
+		_, err = server.readPacket(0, fresh)
+		if err == nil {
+			t.Errorf("corrupt byte %d: readPacket succeeded", i)
+			continue
+		}
+		if _, ok := err.(cbcError); !ok {
+			t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err)
+			continue
+		}
+
+		after := fresh.Len()
+		bytesRead := before - after
+		if bytesRead < maxPacket {
+			t.Errorf("corrupt byte %d: read %d bytes, want at least %d", i, bytesRead, maxPacket)
+			continue
+		}
+
+		// Compare against the first position that passed the checks above;
+		// keying on i > 0 would misfire if byte 0 had failed one of them.
+		if lastRead != -1 && bytesRead != lastRead {
+			t.Errorf("corrupt byte %d: read %d bytes, want %d", i, bytesRead, lastRead)
+		}
+		lastRead = bytesRead
+	}
+}