ssh: cleanup cipher creation logic

Change-Id: I0e6ac0a381ffa53650304f0bea2ba79c3cf1d8c2
Reviewed-on: https://go-review.googlesource.com/87196
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/ssh/cipher.go b/ssh/cipher.go
index 9944139..f2ba117 100644
--- a/ssh/cipher.go
+++ b/ssh/cipher.go
@@ -56,78 +56,78 @@
 	return rc4.NewCipher(key)
 }
 
-type streamCipherMode struct {
-	keySize    int
-	ivSize     int
-	skip       int
-	createFunc func(key, iv []byte) (cipher.Stream, error)
+type cipherMode struct {
+	keySize int
+	ivSize  int
+	create  func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error)
 }
 
-func (c *streamCipherMode) createStream(key, iv []byte) (cipher.Stream, error) {
-	if len(key) < c.keySize {
-		panic("ssh: key length too small for cipher")
-	}
-	if len(iv) < c.ivSize {
-		panic("ssh: iv too small for cipher")
-	}
-
-	stream, err := c.createFunc(key[:c.keySize], iv[:c.ivSize])
-	if err != nil {
-		return nil, err
-	}
-
-	var streamDump []byte
-	if c.skip > 0 {
-		streamDump = make([]byte, 512)
-	}
-
-	for remainingToDump := c.skip; remainingToDump > 0; {
-		dumpThisTime := remainingToDump
-		if dumpThisTime > len(streamDump) {
-			dumpThisTime = len(streamDump)
+func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream, error)) func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
+	return func(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
+		stream, err := createFunc(key, iv)
+		if err != nil {
+			return nil, err
 		}
-		stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime])
-		remainingToDump -= dumpThisTime
-	}
 
-	return stream, nil
+		var streamDump []byte
+		if skip > 0 {
+			streamDump = make([]byte, 512)
+		}
+
+		for remainingToDump := skip; remainingToDump > 0; {
+			dumpThisTime := remainingToDump
+			if dumpThisTime > len(streamDump) {
+				dumpThisTime = len(streamDump)
+			}
+			stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime])
+			remainingToDump -= dumpThisTime
+		}
+
+		mac := macModes[algs.MAC].new(macKey)
+		return &streamPacketCipher{
+			mac:       mac,
+			etm:       macModes[algs.MAC].etm,
+			macResult: make([]byte, mac.Size()),
+			cipher:    stream,
+		}, nil
+	}
 }
 
 // cipherModes documents properties of supported ciphers. Ciphers not included
 // are not supported and will not be negotiated, even if explicitly requested in
 // ClientConfig.Crypto.Ciphers.
-var cipherModes = map[string]*streamCipherMode{
+var cipherModes = map[string]*cipherMode{
 	// Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms
 	// are defined in the order specified in the RFC.
-	"aes128-ctr": {16, aes.BlockSize, 0, newAESCTR},
-	"aes192-ctr": {24, aes.BlockSize, 0, newAESCTR},
-	"aes256-ctr": {32, aes.BlockSize, 0, newAESCTR},
+	"aes128-ctr": {16, aes.BlockSize, streamCipherMode(0, newAESCTR)},
+	"aes192-ctr": {24, aes.BlockSize, streamCipherMode(0, newAESCTR)},
+	"aes256-ctr": {32, aes.BlockSize, streamCipherMode(0, newAESCTR)},
 
 	// Ciphers from RFC4345, which introduces security-improved arcfour ciphers.
 	// They are defined in the order specified in the RFC.
-	"arcfour128": {16, 0, 1536, newRC4},
-	"arcfour256": {32, 0, 1536, newRC4},
+	"arcfour128": {16, 0, streamCipherMode(1536, newRC4)},
+	"arcfour256": {32, 0, streamCipherMode(1536, newRC4)},
 
 	// Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol.
 	// Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and
 	// RC4) has problems with weak keys, and should be used with caution."
 	// RFC4345 introduces improved versions of Arcfour.
-	"arcfour": {16, 0, 0, newRC4},
+	"arcfour": {16, 0, streamCipherMode(0, newRC4)},
 
-	// AEAD ciphers are special cased. If we add any more non-stream
-	// ciphers, we should create a cleaner way to do this.
-	gcmCipherID:        {16, 12, 0, nil},
-	chacha20Poly1305ID: {64, 0, 0, nil},
+	// AEAD ciphers
+	gcmCipherID:        {16, 12, newGCMCipher},
+	chacha20Poly1305ID: {64, 0, newChaCha20Cipher},
 
 	// CBC mode is insecure and so is not included in the default config.
 	// (See http://www.isg.rhul.ac.uk/~kp/SandPfinal.pdf). If absolutely
 	// needed, it's possible to specify a custom Config to enable it.
 	// You should expect that an active attacker can recover plaintext if
 	// you do.
-	aes128cbcID: {16, aes.BlockSize, 0, nil},
+	aes128cbcID: {16, aes.BlockSize, newAESCBCCipher},
 
-	// 3des-cbc is insecure and is disabled by default.
-	tripledescbcID: {24, des.BlockSize, 0, nil},
+	// 3des-cbc is insecure and is not included in the default
+	// config.
+	tripledescbcID: {24, des.BlockSize, newTripleDESCBCCipher},
 }
 
 // prefixLen is the length of the packet prefix that contains the packet length
@@ -307,7 +307,7 @@
 	buf    []byte
 }
 
-func newGCMCipher(iv, key []byte) (packetCipher, error) {
+func newGCMCipher(key, iv, unusedMacKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
 	c, err := aes.NewCipher(key)
 	if err != nil {
 		return nil, err
@@ -425,7 +425,7 @@
 	oracleCamouflage uint32
 }
 
-func newCBCCipher(c cipher.Block, iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
+func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
 	cbc := &cbcCipher{
 		mac:        macModes[algs.MAC].new(macKey),
 		decrypter:  cipher.NewCBCDecrypter(c, iv),
@@ -439,13 +439,13 @@
 	return cbc, nil
 }
 
-func newAESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
+func newAESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
 	c, err := aes.NewCipher(key)
 	if err != nil {
 		return nil, err
 	}
 
-	cbc, err := newCBCCipher(c, iv, key, macKey, algs)
+	cbc, err := newCBCCipher(c, key, iv, macKey, algs)
 	if err != nil {
 		return nil, err
 	}
@@ -453,13 +453,13 @@
 	return cbc, nil
 }
 
-func newTripleDESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
+func newTripleDESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
 	c, err := des.NewTripleDESCipher(key)
 	if err != nil {
 		return nil, err
 	}
 
-	cbc, err := newCBCCipher(c, iv, key, macKey, algs)
+	cbc, err := newCBCCipher(c, key, iv, macKey, algs)
 	if err != nil {
 		return nil, err
 	}
@@ -646,9 +646,9 @@
 	buf        []byte
 }
 
-func newChaCha20Cipher(key []byte) (packetCipher, error) {
+func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
 	if len(key) != 64 {
-		panic("key length")
+		panic(len(key))
 	}
 
 	c := &chacha20Poly1305Cipher{
diff --git a/ssh/transport.go b/ssh/transport.go
index 5a56753..f6fae1d 100644
--- a/ssh/transport.go
+++ b/ssh/transport.go
@@ -233,51 +233,22 @@
 	clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
 )
 
-// generateKeys generates key material for IV, MAC and encryption.
-func generateKeys(d direction, algs directionAlgorithms, kex *kexResult) (iv, key, macKey []byte) {
-	cipherMode := cipherModes[algs.Cipher]
-	macMode := macModes[algs.MAC]
-
-	iv = make([]byte, cipherMode.ivSize)
-	key = make([]byte, cipherMode.keySize)
-	macKey = make([]byte, macMode.keySize)
-
-	generateKeyMaterial(iv, d.ivTag, kex)
-	generateKeyMaterial(key, d.keyTag, kex)
-	generateKeyMaterial(macKey, d.macKeyTag, kex)
-	return
-}
-
 // setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
 // described in RFC 4253, section 6.4. direction should either be serverKeys
 // (to setup server->client keys) or clientKeys (for client->server keys).
 func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) {
-	iv, key, macKey := generateKeys(d, algs, kex)
+	cipherMode := cipherModes[algs.Cipher]
+	macMode := macModes[algs.MAC]
 
-	switch algs.Cipher {
-	case chacha20Poly1305ID:
-		return newChaCha20Cipher(key)
-	case gcmCipherID:
-		return newGCMCipher(iv, key)
-	case aes128cbcID:
-		return newAESCBCCipher(iv, key, macKey, algs)
-	case tripledescbcID:
-		return newTripleDESCBCCipher(iv, key, macKey, algs)
-	}
+	iv := make([]byte, cipherMode.ivSize)
+	key := make([]byte, cipherMode.keySize)
+	macKey := make([]byte, macMode.keySize)
 
-	c := &streamPacketCipher{
-		mac: macModes[algs.MAC].new(macKey),
-		etm: macModes[algs.MAC].etm,
-	}
-	c.macResult = make([]byte, c.mac.Size())
+	generateKeyMaterial(iv, d.ivTag, kex)
+	generateKeyMaterial(key, d.keyTag, kex)
+	generateKeyMaterial(macKey, d.macKeyTag, kex)
 
-	var err error
-	c.cipher, err = cipherModes[algs.Cipher].createStream(key, iv)
-	if err != nil {
-		return nil, err
-	}
-
-	return c, nil
+	return cipherModes[algs.Cipher].create(key, iv, macKey, algs)
 }
 
 // generateKeyMaterial fills out with key material generated from tag, K, H