ssh: reject unencrypted keys from ParsePrivateKeyWithPassphrase

The behavior of ParsePrivateKeyWithPassphrase when the key is
unencrypted is unspecified. Currently, it just parses them like
ParsePrivateKey, which is unlikely to be what anyone wants: for us to
ignore a passphrase that they explicitly passed. It also makes the
implementation of encrypted OpenSSH keys in the next CL more confused.

Instead, make ParsePrivateKey return a PassphraseNeededError, so the
application logic can be ParsePrivateKey -> detect encrypted key ->
obtain passphrase -> ParsePrivateKeyWithPassphrase. That error will also
let us return the public key for OpenSSH keys.

Change-Id: Ife4fb2499ae538bef36e353adf9bc8e902662386
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/207599
Run-TryBot: Filippo Valsorda <filippo@golang.org>
Run-TryBot: Han-Wen Nienhuys <hanwen@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Han-Wen Nienhuys <hanwen@google.com>
diff --git a/ssh/keys.go b/ssh/keys.go
index 1b536a5..c148ad4 100644
--- a/ssh/keys.go
+++ b/ssh/keys.go
@@ -1055,7 +1055,8 @@
 }
 
 // ParsePrivateKey returns a Signer from a PEM encoded private key. It supports
-// the same keys as ParseRawPrivateKey.
+// the same keys as ParseRawPrivateKey. If the private key is encrypted, it
+// will return a PassphraseMissingError.
 func ParsePrivateKey(pemBytes []byte) (Signer, error) {
 	key, err := ParseRawPrivateKey(pemBytes)
 	if err != nil {
@@ -1068,8 +1069,8 @@
 // ParsePrivateKeyWithPassphrase returns a Signer from a PEM encoded private
 // key and passphrase. It supports the same keys as
 // ParseRawPrivateKeyWithPassphrase.
-func ParsePrivateKeyWithPassphrase(pemBytes, passPhrase []byte) (Signer, error) {
-	key, err := ParseRawPrivateKeyWithPassphrase(pemBytes, passPhrase)
+func ParsePrivateKeyWithPassphrase(pemBytes, passphrase []byte) (Signer, error) {
+	key, err := ParseRawPrivateKeyWithPassphrase(pemBytes, passphrase)
 	if err != nil {
 		return nil, err
 	}
@@ -1085,8 +1086,21 @@
 	return strings.Contains(block.Headers["Proc-Type"], "ENCRYPTED")
 }
 
+// A PassphraseMissingError indicates that parsing this private key requires a
+// passphrase. Use ParsePrivateKeyWithPassphrase.
+type PassphraseMissingError struct {
+	// PublicKey will be set if the private key format includes an unencrypted
+	// public key along with the encrypted private key.
+	PublicKey PublicKey
+}
+
+func (*PassphraseMissingError) Error() string {
+	return "ssh: this private key is passphrase protected"
+}
+
 // ParseRawPrivateKey returns a private key from a PEM encoded private key. It
-// supports RSA (PKCS#1), PKCS#8, DSA (OpenSSL), and ECDSA private keys.
+// supports RSA (PKCS#1), PKCS#8, DSA (OpenSSL), and ECDSA private keys. If the
+// private key is encrypted, it will return a PassphraseMissingError.
 func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) {
 	block, _ := pem.Decode(pemBytes)
 	if block == nil {
@@ -1094,7 +1108,7 @@
 	}
 
 	if encryptedBlock(block) {
-		return nil, errors.New("ssh: cannot decode encrypted private keys")
+		return nil, &PassphraseMissingError{}
 	}
 
 	switch block.Type {
@@ -1117,24 +1131,22 @@
 // ParseRawPrivateKeyWithPassphrase returns a private key decrypted with
 // passphrase from a PEM encoded private key. If wrong passphrase, return
 // x509.IncorrectPasswordError.
-func ParseRawPrivateKeyWithPassphrase(pemBytes, passPhrase []byte) (interface{}, error) {
+func ParseRawPrivateKeyWithPassphrase(pemBytes, passphrase []byte) (interface{}, error) {
 	block, _ := pem.Decode(pemBytes)
 	if block == nil {
 		return nil, errors.New("ssh: no key found")
 	}
-	buf := block.Bytes
 
-	if encryptedBlock(block) {
-		if x509.IsEncryptedPEMBlock(block) {
-			var err error
-			buf, err = x509.DecryptPEMBlock(block, passPhrase)
-			if err != nil {
-				if err == x509.IncorrectPasswordError {
-					return nil, err
-				}
-				return nil, fmt.Errorf("ssh: cannot decode encrypted private keys: %v", err)
-			}
+	if !encryptedBlock(block) || !x509.IsEncryptedPEMBlock(block) {
+		return nil, errors.New("ssh: not an encrypted key")
+	}
+
+	buf, err := x509.DecryptPEMBlock(block, passphrase)
+	if err != nil {
+		if err == x509.IncorrectPasswordError {
+			return nil, err
 		}
+		return nil, fmt.Errorf("ssh: cannot decode encrypted private keys: %v", err)
 	}
 
 	switch block.Type {
@@ -1144,8 +1156,6 @@
 		return x509.ParseECPrivateKey(buf)
 	case "DSA PRIVATE KEY":
 		return ParseDSAPrivateKey(buf)
-	case "OPENSSH PRIVATE KEY":
-		return parseOpenSSHPrivateKey(buf)
 	default:
 		return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type)
 	}
diff --git a/ssh/keys_test.go b/ssh/keys_test.go
index 23a3566..d64ef73 100644
--- a/ssh/keys_test.go
+++ b/ssh/keys_test.go
@@ -179,44 +179,45 @@
 	}
 }
 
-// See Issue https://github.com/golang/go/issues/6650.
-func TestParseEncryptedPrivateKeysFails(t *testing.T) {
-	const wantSubstring = "encrypted"
-	for i, tt := range testdata.PEMEncryptedKeys {
-		_, err := ParsePrivateKey(tt.PEMBytes)
-		if err == nil {
-			t.Errorf("#%d key %s: ParsePrivateKey successfully parsed, expected an error", i, tt.Name)
-			continue
-		}
-
-		if !strings.Contains(err.Error(), wantSubstring) {
-			t.Errorf("#%d key %s: got error %q, want substring %q", i, tt.Name, err, wantSubstring)
-		}
-	}
-}
-
-// Parse encrypted private keys with passphrase
 func TestParseEncryptedPrivateKeysWithPassphrase(t *testing.T) {
 	data := []byte("sign me")
 	for _, tt := range testdata.PEMEncryptedKeys {
-		s, err := ParsePrivateKeyWithPassphrase(tt.PEMBytes, []byte(tt.EncryptionKey))
-		if err != nil {
-			t.Fatalf("ParsePrivateKeyWithPassphrase returned error: %s", err)
-			continue
-		}
-		sig, err := s.Sign(rand.Reader, data)
-		if err != nil {
-			t.Fatalf("dsa.Sign: %v", err)
-		}
-		if err := s.PublicKey().Verify(data, sig); err != nil {
-			t.Errorf("Verify failed: %v", err)
-		}
-	}
+		t.Run(tt.Name, func(t *testing.T) {
+			_, err := ParsePrivateKeyWithPassphrase(tt.PEMBytes, []byte("incorrect"))
+			if err != x509.IncorrectPasswordError {
+				t.Errorf("got %v want IncorrectPasswordError", err)
+			}
 
-	tt := testdata.PEMEncryptedKeys[0]
-	_, err := ParsePrivateKeyWithPassphrase(tt.PEMBytes, []byte("incorrect"))
-	if err != x509.IncorrectPasswordError {
-		t.Fatalf("got %v want IncorrectPasswordError", err)
+			s, err := ParsePrivateKeyWithPassphrase(tt.PEMBytes, []byte(tt.EncryptionKey))
+			if err != nil {
+				t.Fatalf("ParsePrivateKeyWithPassphrase returned error: %s", err)
+			}
+
+			sig, err := s.Sign(rand.Reader, data)
+			if err != nil {
+				t.Fatalf("Signer.Sign: %v", err)
+			}
+			if err := s.PublicKey().Verify(data, sig); err != nil {
+				t.Errorf("Verify failed: %v", err)
+			}
+
+			_, err = ParsePrivateKey(tt.PEMBytes)
+			if err == nil {
+				t.Fatalf("ParsePrivateKey succeeded, expected an error")
+			}
+
+			if err, ok := err.(*PassphraseMissingError); !ok {
+				t.Errorf("got error %q, want PassphraseMissingError", err)
+			} else if tt.IncludesPublicKey {
+				if err.PublicKey == nil {
+					t.Fatalf("expected PassphraseMissingError.PublicKey not to be nil")
+				}
+				got, want := err.PublicKey.Marshal(), s.PublicKey().Marshal()
+				if !bytes.Equal(got, want) {
+					t.Errorf("error field %q doesn't match signer public key %q", got, want)
+				}
+			}
+		})
 	}
 }
 
diff --git a/ssh/testdata/keys.go b/ssh/testdata/keys.go
index 90181bc..0df38cd 100644
--- a/ssh/testdata/keys.go
+++ b/ssh/testdata/keys.go
@@ -164,9 +164,10 @@
 }
 
 var PEMEncryptedKeys = []struct {
-	Name          string
-	EncryptionKey string
-	PEMBytes      []byte
+	Name              string
+	EncryptionKey     string
+	IncludesPublicKey bool
+	PEMBytes          []byte
 }{
 	0: {
 		Name:          "rsa-encrypted",