ssh: support encrypted OpenSSH private keys

Includes the bcrypt_pbkdf package by Dmitry Chestnykh,
submitted with permission on his behalf under the CLA:
https://go-review.googlesource.com/c/crypto/+/207600/2#message-6a035dd62ff76f6c9367299b911076a1be237fb8

Fixes golang/go#18692

Change-Id: I74e3ab355a8d720948d64d87adc009783a9d9732
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/207600
Run-TryBot: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
diff --git a/ssh/internal/bcrypt_pbkdf/bcrypt_pbkdf.go b/ssh/internal/bcrypt_pbkdf/bcrypt_pbkdf.go
new file mode 100644
index 0000000..af81d26
--- /dev/null
+++ b/ssh/internal/bcrypt_pbkdf/bcrypt_pbkdf.go
@@ -0,0 +1,93 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package bcrypt_pbkdf implements bcrypt_pbkdf(3) from OpenBSD.
+//
+// See https://flak.tedunangst.com/post/bcrypt-pbkdf and
+// https://cvsweb.openbsd.org/cgi-bin/cvsweb/src/lib/libutil/bcrypt_pbkdf.c.
+package bcrypt_pbkdf
+
+import (
+	"crypto/sha512"
+	"errors"
+	"golang.org/x/crypto/blowfish"
+)
+
+const blockSize = 32
+
+// Key derives a key from the password, salt and rounds count, returning a
+// []byte of length keyLen that can be used as cryptographic key.
+func Key(password, salt []byte, rounds, keyLen int) ([]byte, error) {
+	if rounds < 1 {
+		return nil, errors.New("bcrypt_pbkdf: number of rounds is too small")
+	}
+	if len(password) == 0 {
+		return nil, errors.New("bcrypt_pbkdf: empty password")
+	}
+	if len(salt) == 0 || len(salt) > 1<<20 {
+		return nil, errors.New("bcrypt_pbkdf: bad salt length")
+	}
+	if keyLen > 1024 {
+		return nil, errors.New("bcrypt_pbkdf: keyLen is too large")
+	}
+
+	numBlocks := (keyLen + blockSize - 1) / blockSize
+	key := make([]byte, numBlocks*blockSize)
+
+	h := sha512.New()
+	h.Write(password)
+	shapass := h.Sum(nil)
+
+	shasalt := make([]byte, 0, sha512.Size)
+	cnt, tmp := make([]byte, 4), make([]byte, blockSize)
+	for block := 1; block <= numBlocks; block++ {
+		h.Reset()
+		h.Write(salt)
+		cnt[0] = byte(block >> 24)
+		cnt[1] = byte(block >> 16)
+		cnt[2] = byte(block >> 8)
+		cnt[3] = byte(block)
+		h.Write(cnt)
+		bcryptHash(tmp, shapass, h.Sum(shasalt))
+
+		out := make([]byte, blockSize)
+		copy(out, tmp)
+		for i := 2; i <= rounds; i++ {
+			h.Reset()
+			h.Write(tmp)
+			bcryptHash(tmp, shapass, h.Sum(shasalt))
+			for j := 0; j < len(out); j++ {
+				out[j] ^= tmp[j]
+			}
+		}
+
+		for i, v := range out {
+			key[i*numBlocks+(block-1)] = v
+		}
+	}
+	return key[:keyLen], nil
+}
+
+var magic = []byte("OxychromaticBlowfishSwatDynamite")
+
+func bcryptHash(out, shapass, shasalt []byte) {
+	c, err := blowfish.NewSaltedCipher(shapass, shasalt)
+	if err != nil {
+		panic(err)
+	}
+	for i := 0; i < 64; i++ {
+		blowfish.ExpandKey(shasalt, c)
+		blowfish.ExpandKey(shapass, c)
+	}
+	copy(out, magic)
+	for i := 0; i < 32; i += 8 {
+		for j := 0; j < 64; j++ {
+			c.Encrypt(out[i:i+8], out[i:i+8])
+		}
+	}
+	// Swap bytes due to different endianness.
+	for i := 0; i < 32; i += 4 {
+		out[i+3], out[i+2], out[i+1], out[i] = out[i], out[i+1], out[i+2], out[i+3]
+	}
+}
diff --git a/ssh/internal/bcrypt_pbkdf/bcrypt_pbkdf_test.go b/ssh/internal/bcrypt_pbkdf/bcrypt_pbkdf_test.go
new file mode 100644
index 0000000..20b7889
--- /dev/null
+++ b/ssh/internal/bcrypt_pbkdf/bcrypt_pbkdf_test.go
@@ -0,0 +1,97 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package bcrypt_pbkdf
+
+import (
+	"bytes"
+	"testing"
+)
+
+// Test vectors generated by the reference implementation from OpenBSD.
+var golden = []struct {
+	rounds                 int
+	password, salt, result []byte
+}{
+	{
+		12,
+		[]byte("password"),
+		[]byte("salt"),
+		[]byte{
+			0x1a, 0xe4, 0x2c, 0x05, 0xd4, 0x87, 0xbc, 0x02, 0xf6,
+			0x49, 0x21, 0xa4, 0xeb, 0xe4, 0xea, 0x93, 0xbc, 0xac,
+			0xfe, 0x13, 0x5f, 0xda, 0x99, 0x97, 0x4c, 0x06, 0xb7,
+			0xb0, 0x1f, 0xae, 0x14, 0x9a,
+		},
+	},
+	{
+		3,
+		[]byte("passwordy\x00PASSWORD\x00"),
+		[]byte("salty\x00SALT\x00"),
+		[]byte{
+			0x7f, 0x31, 0x0b, 0xd3, 0xe7, 0x8c, 0x32, 0x80, 0xc5,
+			0x9c, 0xe4, 0x59, 0x52, 0x11, 0xa2, 0x92, 0x8e, 0x8d,
+			0x4e, 0xc7, 0x44, 0xc1, 0xed, 0x2e, 0xfc, 0x9f, 0x76,
+			0x4e, 0x33, 0x88, 0xe0, 0xad,
+		},
+	},
+	{
+		// See http://thread.gmane.org/gmane.os.openbsd.bugs/20542
+		8,
+		[]byte("секретное слово"),
+		[]byte("посолить немножко"),
+		[]byte{
+			0x8d, 0xf4, 0x3f, 0xc6, 0xfe, 0x13, 0x1f, 0xc4, 0x7f,
+			0x0c, 0x9e, 0x39, 0x22, 0x4b, 0xd9, 0x4c, 0x70, 0xb6,
+			0xfc, 0xc8, 0xee, 0x81, 0x35, 0xfa, 0xdd, 0xf6, 0x11,
+			0x56, 0xe6, 0xcb, 0x27, 0x33, 0xea, 0x76, 0x5f, 0x31,
+			0x5a, 0x3e, 0x1e, 0x4a, 0xfc, 0x35, 0xbf, 0x86, 0x87,
+			0xd1, 0x89, 0x25, 0x4c, 0x1e, 0x05, 0xa6, 0xfe, 0x80,
+			0xc0, 0x61, 0x7f, 0x91, 0x83, 0xd6, 0x72, 0x60, 0xd6,
+			0xa1, 0x15, 0xc6, 0xc9, 0x4e, 0x36, 0x03, 0xe2, 0x30,
+			0x3f, 0xbb, 0x43, 0xa7, 0x6a, 0x64, 0x52, 0x3f, 0xfd,
+			0xa6, 0x86, 0xb1, 0xd4, 0x51, 0x85, 0x43,
+		},
+	},
+}
+
+func TestKey(t *testing.T) {
+	for i, v := range golden {
+		k, err := Key(v.password, v.salt, v.rounds, len(v.result))
+		if err != nil {
+			t.Errorf("%d: %s", i, err)
+			continue
+		}
+		if !bytes.Equal(k, v.result) {
+			t.Errorf("%d: expected\n%x\n, got\n%x\n", i, v.result, k)
+		}
+	}
+}
+
+func TestBcryptHash(t *testing.T) {
+	good := []byte{
+		0x87, 0x90, 0x48, 0x70, 0xee, 0xf9, 0xde, 0xdd, 0xf8, 0xe7,
+		0x61, 0x1a, 0x14, 0x01, 0x06, 0xe6, 0xaa, 0xf1, 0xa3, 0x63,
+		0xd9, 0xa2, 0xc5, 0x04, 0xdb, 0x35, 0x64, 0x43, 0x72, 0x1e,
+		0xb5, 0x55,
+	}
+	var pass, salt [64]byte
+	var result [32]byte
+	for i := 0; i < 64; i++ {
+		pass[i] = byte(i)
+		salt[i] = byte(i + 64)
+	}
+	bcryptHash(result[:], pass[:], salt[:])
+	if !bytes.Equal(result[:], good) {
+		t.Errorf("expected %x, got %x", good, result)
+	}
+}
+
+func BenchmarkKey(b *testing.B) {
+	pass := []byte("password")
+	salt := []byte("salt")
+	for i := 0; i < b.N; i++ {
+		Key(pass, salt, 10, 32)
+	}
+}
diff --git a/ssh/keys.go b/ssh/keys.go
index c148ad4..5377ec8 100644
--- a/ssh/keys.go
+++ b/ssh/keys.go
@@ -7,6 +7,8 @@
 import (
 	"bytes"
 	"crypto"
+	"crypto/aes"
+	"crypto/cipher"
 	"crypto/dsa"
 	"crypto/ecdsa"
 	"crypto/elliptic"
@@ -25,6 +27,7 @@
 	"strings"
 
 	"golang.org/x/crypto/ed25519"
+	"golang.org/x/crypto/ssh/internal/bcrypt_pbkdf"
 )
 
 // These constants represent the algorithm names for key types supported by this
@@ -1122,21 +1125,25 @@
 	case "DSA PRIVATE KEY":
 		return ParseDSAPrivateKey(block.Bytes)
 	case "OPENSSH PRIVATE KEY":
-		return parseOpenSSHPrivateKey(block.Bytes)
+		return parseOpenSSHPrivateKey(block.Bytes, unencryptedOpenSSHKey)
 	default:
 		return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type)
 	}
 }
 
 // ParseRawPrivateKeyWithPassphrase returns a private key decrypted with
-// passphrase from a PEM encoded private key. If wrong passphrase, return
-// x509.IncorrectPasswordError.
+// passphrase from a PEM encoded private key. If the passphrase is wrong, it
+// will return x509.IncorrectPasswordError.
 func ParseRawPrivateKeyWithPassphrase(pemBytes, passphrase []byte) (interface{}, error) {
 	block, _ := pem.Decode(pemBytes)
 	if block == nil {
 		return nil, errors.New("ssh: no key found")
 	}
 
+	if block.Type == "OPENSSH PRIVATE KEY" {
+		return parseOpenSSHPrivateKey(block.Bytes, passphraseProtectedOpenSSHKey(passphrase))
+	}
+
 	if !encryptedBlock(block) || !x509.IsEncryptedPEMBlock(block) {
 		return nil, errors.New("ssh: not an encrypted key")
 	}
@@ -1193,9 +1200,60 @@
 	}, nil
 }
 
-// Implemented based on the documentation at
-// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key
-func parseOpenSSHPrivateKey(key []byte) (crypto.PrivateKey, error) {
+func unencryptedOpenSSHKey(cipherName, kdfName, kdfOpts string, privKeyBlock []byte) ([]byte, error) {
+	if kdfName != "none" || cipherName != "none" {
+		return nil, &PassphraseMissingError{}
+	}
+	if kdfOpts != "" {
+		return nil, errors.New("ssh: invalid openssh private key")
+	}
+	return privKeyBlock, nil
+}
+
+func passphraseProtectedOpenSSHKey(passphrase []byte) openSSHDecryptFunc {
+	return func(cipherName, kdfName, kdfOpts string, privKeyBlock []byte) ([]byte, error) {
+		if kdfName == "none" || cipherName == "none" {
+			return nil, errors.New("ssh: key is not password protected")
+		}
+		if kdfName != "bcrypt" {
+			return nil, fmt.Errorf("ssh: unknown KDF %q, only supports %q", kdfName, "bcrypt")
+		}
+
+		var opts struct {
+			Salt   string
+			Rounds uint32
+		}
+		if err := Unmarshal([]byte(kdfOpts), &opts); err != nil {
+			return nil, err
+		}
+
+		k, err := bcrypt_pbkdf.Key(passphrase, []byte(opts.Salt), int(opts.Rounds), 32+16)
+		if err != nil {
+			return nil, err
+		}
+		key, iv := k[:32], k[32:]
+
+		if cipherName != "aes256-ctr" {
+			return nil, fmt.Errorf("ssh: unknown cipher %q, only supports %q", cipherName, "aes256-ctr")
+		}
+		c, err := aes.NewCipher(key)
+		if err != nil {
+			return nil, err
+		}
+		ctr := cipher.NewCTR(c, iv)
+		ctr.XORKeyStream(privKeyBlock, privKeyBlock)
+
+		return privKeyBlock, nil
+	}
+}
+
+type openSSHDecryptFunc func(CipherName, KdfName, KdfOpts string, PrivKeyBlock []byte) ([]byte, error)
+
+// parseOpenSSHPrivateKey parses an OpenSSH private key, using the decrypt
+// function to unwrap the encrypted portion. unencryptedOpenSSHKey can be used
+// as the decrypt function to parse an unencrypted private key. See
+// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key.
+func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.PrivateKey, error) {
 	const magic = "openssh-key-v1\x00"
 	if len(key) < len(magic) || string(key[:len(magic)]) != magic {
 		return nil, errors.New("ssh: invalid openssh private key format")
@@ -1214,9 +1272,22 @@
 	if err := Unmarshal(remaining, &w); err != nil {
 		return nil, err
 	}
+	if w.NumKeys != 1 {
+		// We only support single key files, and so does OpenSSH.
+		// https://github.com/openssh/openssh-portable/blob/4103a3ec7/sshkey.c#L4171
+		return nil, errors.New("ssh: multi-key files are not supported")
+	}
 
-	if w.KdfName != "none" || w.CipherName != "none" {
-		return nil, errors.New("ssh: cannot decode encrypted private keys")
+	privKeyBlock, err := decrypt(w.CipherName, w.KdfName, w.KdfOpts, w.PrivKeyBlock)
+	if err != nil {
+		if err, ok := err.(*PassphraseMissingError); ok {
+			pub, errPub := ParsePublicKey(w.PubKey)
+			if errPub != nil {
+				return nil, fmt.Errorf("ssh: failed to parse embedded public key: %v", errPub)
+			}
+			err.PublicKey = pub
+		}
+		return nil, err
 	}
 
 	pk1 := struct {
@@ -1226,12 +1297,11 @@
 		Rest    []byte `ssh:"rest"`
 	}{}
 
-	if err := Unmarshal(w.PrivKeyBlock, &pk1); err != nil {
-		return nil, err
-	}
-
-	if pk1.Check1 != pk1.Check2 {
-		return nil, errors.New("ssh: checkint mismatch")
+	if err := Unmarshal(privKeyBlock, &pk1); err != nil || pk1.Check1 != pk1.Check2 {
+		if w.CipherName != "none" {
+			return nil, x509.IncorrectPasswordError
+		}
+		return nil, errors.New("ssh: malformed OpenSSH key")
 	}
 
 	// we only handle ed25519 and rsa keys currently
@@ -1253,10 +1323,8 @@
 			return nil, err
 		}
 
-		for i, b := range key.Pad {
-			if int(b) != i+1 {
-				return nil, errors.New("ssh: padding not as expected")
-			}
+		if err := checkOpenSSHKeyPadding(key.Pad); err != nil {
+			return nil, err
 		}
 
 		pk := &rsa.PrivateKey{
@@ -1291,10 +1359,8 @@
 			return nil, errors.New("ssh: private key unexpected length")
 		}
 
-		for i, b := range key.Pad {
-			if int(b) != i+1 {
-				return nil, errors.New("ssh: padding not as expected")
-			}
+		if err := checkOpenSSHKeyPadding(key.Pad); err != nil {
+			return nil, err
 		}
 
 		pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize))
@@ -1305,6 +1371,15 @@
 	}
 }
 
+func checkOpenSSHKeyPadding(pad []byte) error {
+	for i, b := range pad {
+		if int(b) != i+1 {
+			return errors.New("ssh: padding not as expected")
+		}
+	}
+	return nil
+}
+
 // FingerprintLegacyMD5 returns the user presentation of the key's
 // fingerprint as described by RFC 4716 section 4.
 func FingerprintLegacyMD5(pubKey PublicKey) string {
diff --git a/ssh/testdata/keys.go b/ssh/testdata/keys.go
index 0df38cd..72cfaa0 100644
--- a/ssh/testdata/keys.go
+++ b/ssh/testdata/keys.go
@@ -225,6 +225,21 @@
 -----END DSA PRIVATE KEY-----
 `),
 	},
+
+	2: {
+		Name:              "ed25519-encrypted",
+		EncryptionKey:     "password",
+		IncludesPublicKey: true,
+		PEMBytes: []byte(`-----BEGIN OPENSSH PRIVATE KEY-----
+b3BlbnNzaC1rZXktdjEAAAAACmFlczI1Ni1jdHIAAAAGYmNyeXB0AAAAGAAAABDKj29BlC
+ocEWuVhQ94/RjoAAAAEAAAAAEAAAAzAAAAC3NzaC1lZDI1NTE5AAAAIIw1gSurPTDwZidA
+2AIjQZgoQi3IFn9jBtFdP10/Jj7DAAAAoFGkQbB2teSU7ikUsnc7ct2aH3pitM359lNVUh
+7DQbJWMjbQFbrBYyDJP+ALj1/RZmP2yoIf7/wr99q53/pm28Xp1gGP5V2RGRJYCA6kgFIH
+xdB6KEw1Ce7Bz8JaDIeagAGd3xtQTH3cuuleVxCZZnk9NspsPxigADKCls/RUiK7F+z3Qf
+Lvs9+PH8nIuhFMYZgo3liqZbVS5z4Fqhyzyq4=
+-----END OPENSSH PRIVATE KEY-----
+`),
+	},
 }
 
 // SKData contains a list of PubKeys backed by U2F/FIDO2 Security Keys and their test data.