ssh: support for marshaling keys using the OpenSSH format

This adds methods to marshal private keys, encrypted and unencrypted
to the OpenSSH format.

Fixes golang/go#37132

Change-Id: I1a95301f789ce04858e6b147748c6e8b7700384b
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/218620
Run-TryBot: Roland Shoemaker <roland@golang.org>
TryBot-Result: Gopher Robot <gobot@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Auto-Submit: Roland Shoemaker <roland@golang.org>
diff --git a/ssh/keys.go b/ssh/keys.go
index dac8ee7..1bf28d8 100644
--- a/ssh/keys.go
+++ b/ssh/keys.go
@@ -13,11 +13,13 @@
 	"crypto/ecdsa"
 	"crypto/elliptic"
 	"crypto/md5"
+	"crypto/rand"
 	"crypto/rsa"
 	"crypto/sha256"
 	"crypto/x509"
 	"encoding/asn1"
 	"encoding/base64"
+	"encoding/binary"
 	"encoding/hex"
 	"encoding/pem"
 	"errors"
@@ -295,6 +297,18 @@
 	return b.Bytes()
 }
 
+// MarshalPrivateKey returns a PEM block with the private key serialized in the
+// OpenSSH format.
+func MarshalPrivateKey(key crypto.PrivateKey, comment string) (*pem.Block, error) {
+	return marshalOpenSSHPrivateKey(key, comment, unencryptedOpenSSHMarshaler)
+}
+
+// MarshalPrivateKeyWithPassphrase returns a PEM block holding the encrypted
+// private key serialized in the OpenSSH format.
+func MarshalPrivateKeyWithPassphrase(key crypto.PrivateKey, comment string, passphrase []byte) (*pem.Block, error) {
+	return marshalOpenSSHPrivateKey(key, comment, passphraseProtectedOpenSSHMarshaler(passphrase))
+}
+
 // PublicKey represents a public key using an unspecified algorithm.
 //
 // Some PublicKeys provided by this package also implement CryptoPublicKey.
@@ -1241,28 +1255,106 @@
 	}
 }
 
+func unencryptedOpenSSHMarshaler(privKeyBlock []byte) ([]byte, string, string, string, error) {
+	key := generateOpenSSHPadding(privKeyBlock, 8)
+	return key, "none", "none", "", nil
+}
+
+func passphraseProtectedOpenSSHMarshaler(passphrase []byte) openSSHEncryptFunc {
+	return func(privKeyBlock []byte) ([]byte, string, string, string, error) {
+		salt := make([]byte, 16)
+		if _, err := rand.Read(salt); err != nil {
+			return nil, "", "", "", err
+		}
+
+		opts := struct {
+			Salt   []byte
+			Rounds uint32
+		}{salt, 16}
+
+		// Derive key to encrypt the private key block.
+		k, err := bcrypt_pbkdf.Key(passphrase, salt, int(opts.Rounds), 32+aes.BlockSize)
+		if err != nil {
+			return nil, "", "", "", err
+		}
+
+		// Add padding matching the block size of AES.
+		keyBlock := generateOpenSSHPadding(privKeyBlock, aes.BlockSize)
+
+		// Encrypt the private key using the derived secret.
+
+		dst := make([]byte, len(keyBlock))
+		key, iv := k[:32], k[32:]
+		block, err := aes.NewCipher(key)
+		if err != nil {
+			return nil, "", "", "", err
+		}
+
+		stream := cipher.NewCTR(block, iv)
+		stream.XORKeyStream(dst, keyBlock)
+
+		return dst, "aes256-ctr", "bcrypt", string(Marshal(opts)), nil
+	}
+}
+
+const privateKeyAuthMagic = "openssh-key-v1\x00"
+
 type openSSHDecryptFunc func(CipherName, KdfName, KdfOpts string, PrivKeyBlock []byte) ([]byte, error)
+type openSSHEncryptFunc func(PrivKeyBlock []byte) (ProtectedKeyBlock []byte, cipherName, kdfName, kdfOptions string, err error)
+
+type openSSHEncryptedPrivateKey struct {
+	CipherName   string
+	KdfName      string
+	KdfOpts      string
+	NumKeys      uint32
+	PubKey       []byte
+	PrivKeyBlock []byte
+}
+
+type openSSHPrivateKey struct {
+	Check1  uint32
+	Check2  uint32
+	Keytype string
+	Rest    []byte `ssh:"rest"`
+}
+
+type openSSHRSAPrivateKey struct {
+	N       *big.Int
+	E       *big.Int
+	D       *big.Int
+	Iqmp    *big.Int
+	P       *big.Int
+	Q       *big.Int
+	Comment string
+	Pad     []byte `ssh:"rest"`
+}
+
+type openSSHEd25519PrivateKey struct {
+	Pub     []byte
+	Priv    []byte
+	Comment string
+	Pad     []byte `ssh:"rest"`
+}
+
+type openSSHECDSAPrivateKey struct {
+	Curve   string
+	Pub     []byte
+	D       *big.Int
+	Comment string
+	Pad     []byte `ssh:"rest"`
+}
 
 // parseOpenSSHPrivateKey parses an OpenSSH private key, using the decrypt
 // function to unwrap the encrypted portion. unencryptedOpenSSHKey can be used
 // as the decrypt function to parse an unencrypted private key. See
 // https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key.
 func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.PrivateKey, error) {
-	const magic = "openssh-key-v1\x00"
-	if len(key) < len(magic) || string(key[:len(magic)]) != magic {
+	if len(key) < len(privateKeyAuthMagic) || string(key[:len(privateKeyAuthMagic)]) != privateKeyAuthMagic {
 		return nil, errors.New("ssh: invalid openssh private key format")
 	}
-	remaining := key[len(magic):]
+	remaining := key[len(privateKeyAuthMagic):]
 
-	var w struct {
-		CipherName   string
-		KdfName      string
-		KdfOpts      string
-		NumKeys      uint32
-		PubKey       []byte
-		PrivKeyBlock []byte
-	}
-
+	var w openSSHEncryptedPrivateKey
 	if err := Unmarshal(remaining, &w); err != nil {
 		return nil, err
 	}
@@ -1284,13 +1376,7 @@
 		return nil, err
 	}
 
-	pk1 := struct {
-		Check1  uint32
-		Check2  uint32
-		Keytype string
-		Rest    []byte `ssh:"rest"`
-	}{}
-
+	var pk1 openSSHPrivateKey
 	if err := Unmarshal(privKeyBlock, &pk1); err != nil || pk1.Check1 != pk1.Check2 {
 		if w.CipherName != "none" {
 			return nil, x509.IncorrectPasswordError
@@ -1300,18 +1386,7 @@
 
 	switch pk1.Keytype {
 	case KeyAlgoRSA:
-		// https://github.com/openssh/openssh-portable/blob/master/sshkey.c#L2760-L2773
-		key := struct {
-			N       *big.Int
-			E       *big.Int
-			D       *big.Int
-			Iqmp    *big.Int
-			P       *big.Int
-			Q       *big.Int
-			Comment string
-			Pad     []byte `ssh:"rest"`
-		}{}
-
+		var key openSSHRSAPrivateKey
 		if err := Unmarshal(pk1.Rest, &key); err != nil {
 			return nil, err
 		}
@@ -1337,13 +1412,7 @@
 
 		return pk, nil
 	case KeyAlgoED25519:
-		key := struct {
-			Pub     []byte
-			Priv    []byte
-			Comment string
-			Pad     []byte `ssh:"rest"`
-		}{}
-
+		var key openSSHEd25519PrivateKey
 		if err := Unmarshal(pk1.Rest, &key); err != nil {
 			return nil, err
 		}
@@ -1360,14 +1429,7 @@
 		copy(pk, key.Priv)
 		return &pk, nil
 	case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521:
-		key := struct {
-			Curve   string
-			Pub     []byte
-			D       *big.Int
-			Comment string
-			Pad     []byte `ssh:"rest"`
-		}{}
-
+		var key openSSHECDSAPrivateKey
 		if err := Unmarshal(pk1.Rest, &key); err != nil {
 			return nil, err
 		}
@@ -1415,6 +1477,131 @@
 	}
 }
 
+func marshalOpenSSHPrivateKey(key crypto.PrivateKey, comment string, encrypt openSSHEncryptFunc) (*pem.Block, error) {
+	var w openSSHEncryptedPrivateKey
+	var pk1 openSSHPrivateKey
+
+	// Random check bytes.
+	var check uint32
+	if err := binary.Read(rand.Reader, binary.BigEndian, &check); err != nil {
+		return nil, err
+	}
+
+	pk1.Check1 = check
+	pk1.Check2 = check
+	w.NumKeys = 1
+
+	// Use a []byte directly on ed25519 keys.
+	if k, ok := key.(*ed25519.PrivateKey); ok {
+		key = *k
+	}
+
+	switch k := key.(type) {
+	case *rsa.PrivateKey:
+		E := new(big.Int).SetInt64(int64(k.PublicKey.E))
+		// Marshal public key:
+		// E and N are in reversed order in the public and private key.
+		pubKey := struct {
+			KeyType string
+			E       *big.Int
+			N       *big.Int
+		}{
+			KeyAlgoRSA,
+			E, k.PublicKey.N,
+		}
+		w.PubKey = Marshal(pubKey)
+
+		// Marshal private key.
+		key := openSSHRSAPrivateKey{
+			N:       k.PublicKey.N,
+			E:       E,
+			D:       k.D,
+			Iqmp:    k.Precomputed.Qinv,
+			P:       k.Primes[0],
+			Q:       k.Primes[1],
+			Comment: comment,
+		}
+		pk1.Keytype = KeyAlgoRSA
+		pk1.Rest = Marshal(key)
+	case ed25519.PrivateKey:
+		pub := make([]byte, ed25519.PublicKeySize)
+		priv := make([]byte, ed25519.PrivateKeySize)
+		copy(pub, k[32:])
+		copy(priv, k)
+
+		// Marshal public key.
+		pubKey := struct {
+			KeyType string
+			Pub     []byte
+		}{
+			KeyAlgoED25519, pub,
+		}
+		w.PubKey = Marshal(pubKey)
+
+		// Marshal private key.
+		key := openSSHEd25519PrivateKey{
+			Pub:     pub,
+			Priv:    priv,
+			Comment: comment,
+		}
+		pk1.Keytype = KeyAlgoED25519
+		pk1.Rest = Marshal(key)
+	case *ecdsa.PrivateKey:
+		var curve, keyType string
+		switch name := k.Curve.Params().Name; name {
+		case "P-256":
+			curve = "nistp256"
+			keyType = KeyAlgoECDSA256
+		case "P-384":
+			curve = "nistp384"
+			keyType = KeyAlgoECDSA384
+		case "P-521":
+			curve = "nistp521"
+			keyType = KeyAlgoECDSA521
+		default:
+			return nil, errors.New("ssh: unhandled elliptic curve " + name)
+		}
+
+		pub := elliptic.Marshal(k.Curve, k.PublicKey.X, k.PublicKey.Y)
+
+		// Marshal public key.
+		pubKey := struct {
+			KeyType string
+			Curve   string
+			Pub     []byte
+		}{
+			keyType, curve, pub,
+		}
+		w.PubKey = Marshal(pubKey)
+
+		// Marshal private key.
+		key := openSSHECDSAPrivateKey{
+			Curve:   curve,
+			Pub:     pub,
+			D:       k.D,
+			Comment: comment,
+		}
+		pk1.Keytype = keyType
+		pk1.Rest = Marshal(key)
+	default:
+		return nil, fmt.Errorf("ssh: unsupported key type %T", k)
+	}
+
+	var err error
+	// Add padding and encrypt the key if necessary.
+	w.PrivKeyBlock, w.CipherName, w.KdfName, w.KdfOpts, err = encrypt(Marshal(pk1))
+	if err != nil {
+		return nil, err
+	}
+
+	b := Marshal(w)
+	block := &pem.Block{
+		Type:  "OPENSSH PRIVATE KEY",
+		Bytes: append([]byte(privateKeyAuthMagic), b...),
+	}
+	return block, nil
+}
+
 func checkOpenSSHKeyPadding(pad []byte) error {
 	for i, b := range pad {
 		if int(b) != i+1 {
@@ -1424,6 +1611,13 @@
 	return nil
 }
 
+func generateOpenSSHPadding(block []byte, blockSize int) []byte {
+	for i, l := 0, len(block); (l+i)%blockSize != 0; i++ {
+		block = append(block, byte(i+1))
+	}
+	return block
+}
+
 // FingerprintLegacyMD5 returns the user presentation of the key's
 // fingerprint as described by RFC 4716 section 4.
 func FingerprintLegacyMD5(pubKey PublicKey) string {
diff --git a/ssh/keys_test.go b/ssh/keys_test.go
index 334ef74..a8e216e 100644
--- a/ssh/keys_test.go
+++ b/ssh/keys_test.go
@@ -281,6 +281,74 @@
 	}
 }
 
+func TestMarshalPrivateKey(t *testing.T) {
+	tests := []struct {
+		name string
+	}{
+		{"rsa-openssh-format"},
+		{"ed25519"},
+		{"p256-openssh-format"},
+		{"p384-openssh-format"},
+		{"p521-openssh-format"},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			expected, ok := testPrivateKeys[tt.name]
+			if !ok {
+				t.Fatalf("cannot find key %s", tt.name)
+			}
+
+			block, err := MarshalPrivateKey(expected, "test@golang.org")
+			if err != nil {
+				t.Fatalf("cannot marshal %s: %v", tt.name, err)
+			}
+
+			key, err := ParseRawPrivateKey(pem.EncodeToMemory(block))
+			if err != nil {
+				t.Fatalf("cannot parse %s: %v", tt.name, err)
+			}
+
+			if !reflect.DeepEqual(expected, key) {
+				t.Errorf("unexpected marshaled key %s", tt.name)
+			}
+		})
+	}
+}
+
+func TestMarshalPrivateKeyWithPassphrase(t *testing.T) {
+	tests := []struct {
+		name string
+	}{
+		{"rsa-openssh-format"},
+		{"ed25519"},
+		{"p256-openssh-format"},
+		{"p384-openssh-format"},
+		{"p521-openssh-format"},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			expected, ok := testPrivateKeys[tt.name]
+			if !ok {
+				t.Fatalf("cannot find key %s", tt.name)
+			}
+
+			block, err := MarshalPrivateKeyWithPassphrase(expected, "test@golang.org", []byte("test-passphrase"))
+			if err != nil {
+				t.Fatalf("cannot marshal %s: %v", tt.name, err)
+			}
+
+			key, err := ParseRawPrivateKeyWithPassphrase(pem.EncodeToMemory(block), []byte("test-passphrase"))
+			if err != nil {
+				t.Fatalf("cannot parse %s: %v", tt.name, err)
+			}
+
+			if !reflect.DeepEqual(expected, key) {
+				t.Errorf("unexpected marshaled key %s", tt.name)
+			}
+		})
+	}
+}
+
 type testAuthResult struct {
 	pubKey   PublicKey
 	options  []string