go.crypto/ssh: introduce Signer method, an abstraction of
private keys.
R=agl, jpsugar, jonathan.mark.pittman
CC=golang-dev
https://golang.org/cl/13338044
diff --git a/ssh/certs.go b/ssh/certs.go
index ddece44..53bc8fd 100644
--- a/ssh/certs.go
+++ b/ssh/certs.go
@@ -68,10 +68,6 @@
return algo
}
-func (c *OpenSSHCertV01) RawKey() interface{} {
- return c.Key.RawKey()
-}
-
func (c *OpenSSHCertV01) PrivateKeyAlgo() string {
return c.Key.PrivateKeyAlgo()
}
diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go
index 4460b41..a726595 100644
--- a/ssh/client_auth_test.go
+++ b/ssh/client_auth_test.go
@@ -6,12 +6,7 @@
import (
"bytes"
- "crypto"
"crypto/dsa"
- "crypto/rsa"
- "crypto/x509"
- "encoding/pem"
- "errors"
"io"
"io/ioutil"
"math/big"
@@ -62,32 +57,23 @@
// keychain implements the ClientKeyring interface
type keychain struct {
- keys []interface{}
+ keys []Signer
}
func (k *keychain) Key(i int) (PublicKey, error) {
if i < 0 || i >= len(k.keys) {
return nil, nil
}
- switch key := k.keys[i].(type) {
- case *rsa.PrivateKey:
- return NewRSAPublicKey(&key.PublicKey), nil
- case *dsa.PrivateKey:
- return NewDSAPublicKey(&key.PublicKey), nil
- }
- panic("unknown key type")
+
+ return k.keys[i].PublicKey(), nil
}
func (k *keychain) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) {
- hashFunc := crypto.SHA1
- h := hashFunc.New()
- h.Write(data)
- digest := h.Sum(nil)
- switch key := k.keys[i].(type) {
- case *rsa.PrivateKey:
- return rsa.SignPKCS1v15(rand, key, hashFunc, digest)
- }
- return nil, errors.New("ssh: unknown key type")
+ return k.keys[i].Sign(rand, data)
+}
+
+func (k *keychain) add(key Signer) {
+ k.keys = append(k.keys, key)
}
func (k *keychain) loadPEM(file string) error {
@@ -95,15 +81,11 @@
if err != nil {
return err
}
- block, _ := pem.Decode(buf)
- if block == nil {
- return errors.New("ssh: no key found")
- }
- r, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+ key, err := ParsePrivateKey(buf)
if err != nil {
return err
}
- k.keys = append(k.keys, r)
+ k.add(key)
return nil
}
@@ -126,8 +108,8 @@
// reused internally by tests
var (
- rsakey *rsa.PrivateKey
- dsakey *dsa.PrivateKey
+ rsaKey Signer
+ dsaKey Signer
clientKeychain = new(keychain)
clientPassword = password("tiger")
serverConfig = &ServerConfig{
@@ -135,8 +117,7 @@
return user == "testuser" && pass == string(clientPassword)
},
PublicKeyCallback: func(conn *ServerConn, user, algo string, pubkey []byte) bool {
- rsaKey := &clientKeychain.keys[0].(*rsa.PrivateKey).PublicKey
- key := NewRSAPublicKey(rsaKey)
+ key, _ := clientKeychain.Key(0)
expected := MarshalPublicKey(key)
algoname := key.PublicKeyAlgo()
return user == "testuser" && algo == algoname && bytes.Equal(pubkey, expected)
@@ -157,21 +138,26 @@
)
func init() {
- if err := serverConfig.SetRSAPrivateKey([]byte(testServerPrivateKey)); err != nil {
+ var err error
+ rsaKey, err = ParsePrivateKey([]byte(testServerPrivateKey))
+ if err != nil {
panic("unable to set private key: " + err.Error())
}
+ rawDSAKey := new(dsa.PrivateKey)
- block, _ := pem.Decode([]byte(testClientPrivateKey))
- rsakey, _ = x509.ParsePKCS1PrivateKey(block.Bytes)
-
- clientKeychain.keys = append(clientKeychain.keys, rsakey)
- dsakey = new(dsa.PrivateKey)
// taken from crypto/dsa/dsa_test.go
- dsakey.P, _ = new(big.Int).SetString("A9B5B793FB4785793D246BAE77E8FF63CA52F442DA763C440259919FE1BC1D6065A9350637A04F75A2F039401D49F08E066C4D275A5A65DA5684BC563C14289D7AB8A67163BFBF79D85972619AD2CFF55AB0EE77A9002B0EF96293BDD0F42685EBB2C66C327079F6C98000FBCB79AACDE1BC6F9D5C7B1A97E3D9D54ED7951FEF", 16)
- dsakey.Q, _ = new(big.Int).SetString("E1D3391245933D68A0714ED34BBCB7A1F422B9C1", 16)
- dsakey.G, _ = new(big.Int).SetString("634364FC25248933D01D1993ECABD0657CC0CB2CEED7ED2E3E8AECDFCDC4A25C3B15E9E3B163ACA2984B5539181F3EFF1A5E8903D71D5B95DA4F27202B77D2C44B430BB53741A8D59A8F86887525C9F2A6A5980A195EAA7F2FF910064301DEF89D3AA213E1FAC7768D89365318E370AF54A112EFBA9246D9158386BA1B4EEFDA", 16)
- dsakey.Y, _ = new(big.Int).SetString("32969E5780CFE1C849A1C276D7AEB4F38A23B591739AA2FE197349AEEBD31366AEE5EB7E6C6DDB7C57D02432B30DB5AA66D9884299FAA72568944E4EEDC92EA3FBC6F39F53412FBCC563208F7C15B737AC8910DBC2D9C9B8C001E72FDC40EB694AB1F06A5A2DBD18D9E36C66F31F566742F11EC0A52E9F7B89355C02FB5D32D2", 16)
- dsakey.X, _ = new(big.Int).SetString("5078D4D29795CBE76D3AACFE48C9AF0BCDBEE91A", 16)
+ rawDSAKey.P, _ = new(big.Int).SetString("A9B5B793FB4785793D246BAE77E8FF63CA52F442DA763C440259919FE1BC1D6065A9350637A04F75A2F039401D49F08E066C4D275A5A65DA5684BC563C14289D7AB8A67163BFBF79D85972619AD2CFF55AB0EE77A9002B0EF96293BDD0F42685EBB2C66C327079F6C98000FBCB79AACDE1BC6F9D5C7B1A97E3D9D54ED7951FEF", 16)
+ rawDSAKey.Q, _ = new(big.Int).SetString("E1D3391245933D68A0714ED34BBCB7A1F422B9C1", 16)
+ rawDSAKey.G, _ = new(big.Int).SetString("634364FC25248933D01D1993ECABD0657CC0CB2CEED7ED2E3E8AECDFCDC4A25C3B15E9E3B163ACA2984B5539181F3EFF1A5E8903D71D5B95DA4F27202B77D2C44B430BB53741A8D59A8F86887525C9F2A6A5980A195EAA7F2FF910064301DEF89D3AA213E1FAC7768D89365318E370AF54A112EFBA9246D9158386BA1B4EEFDA", 16)
+ rawDSAKey.Y, _ = new(big.Int).SetString("32969E5780CFE1C849A1C276D7AEB4F38A23B591739AA2FE197349AEEBD31366AEE5EB7E6C6DDB7C57D02432B30DB5AA66D9884299FAA72568944E4EEDC92EA3FBC6F39F53412FBCC563208F7C15B737AC8910DBC2D9C9B8C001E72FDC40EB694AB1F06A5A2DBD18D9E36C66F31F566742F11EC0A52E9F7B89355C02FB5D32D2", 16)
+ rawDSAKey.X, _ = new(big.Int).SetString("5078D4D29795CBE76D3AACFE48C9AF0BCDBEE91A", 16)
+
+ dsaKey, err = NewSignerFromKey(rawDSAKey)
+ if err != nil {
+ panic("NewSignerFromKey: " + err.Error())
+ }
+ clientKeychain.add(rsaKey)
+ serverConfig.AddHostKey(rsaKey)
}
// newMockAuthServer creates a new Server bound to
@@ -287,7 +273,8 @@
// the mock server will only authenticate ssh-rsa keys
func TestClientAuthInvalidPublicKey(t *testing.T) {
kc := new(keychain)
- kc.keys = append(kc.keys, dsakey)
+
+ kc.add(dsaKey)
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
@@ -305,7 +292,8 @@
// the client should authenticate with the second key
func TestClientAuthRSAandDSA(t *testing.T) {
kc := new(keychain)
- kc.keys = append(kc.keys, dsakey, rsakey)
+ kc.add(dsaKey)
+ kc.add(rsaKey)
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
@@ -321,7 +309,7 @@
func TestClientHMAC(t *testing.T) {
kc := new(keychain)
- kc.keys = append(kc.keys, rsakey)
+ kc.add(rsaKey)
for _, mac := range DefaultMACOrder {
config := &ClientConfig{
User: "testuser",
@@ -343,7 +331,6 @@
// issue 4285.
func TestClientUnsupportedCipher(t *testing.T) {
kc := new(keychain)
- kc.keys = append(kc.keys, rsakey)
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
@@ -362,7 +349,6 @@
func TestClientUnsupportedKex(t *testing.T) {
kc := new(keychain)
- kc.keys = append(kc.keys, rsakey)
config := &ClientConfig{
User: "testuser",
Auth: []ClientAuth{
diff --git a/ssh/common.go b/ssh/common.go
index e536c1a..7e6e5dc 100644
--- a/ssh/common.go
+++ b/ssh/common.go
@@ -6,7 +6,6 @@
import (
"crypto"
- "crypto/elliptic"
"errors"
"fmt"
"math/big"
@@ -223,19 +222,6 @@
return c.MACs
}
-// ecHash returns the hash to match the given elliptic curve, see RFC
-// 5656, section 6.2.1
-func ecHash(curve elliptic.Curve) crypto.Hash {
- bitSize := curve.Params().BitSize
- switch {
- case bitSize <= 256:
- return crypto.SHA256
- case bitSize <= 384:
- return crypto.SHA384
- }
- return crypto.SHA512
-}
-
// serialize a signed slice according to RFC 4254 6.6. The name should
// be a key type name, rather than a cert type name.
func serializeSignature(name string, sig []byte) []byte {
diff --git a/ssh/example_test.go b/ssh/example_test.go
index 715afb3..a88a677 100644
--- a/ssh/example_test.go
+++ b/ssh/example_test.go
@@ -23,14 +23,18 @@
},
}
- pemBytes, err := ioutil.ReadFile("id_rsa")
+ privateBytes, err := ioutil.ReadFile("id_rsa")
if err != nil {
panic("Failed to load private key")
}
- if err = config.SetRSAPrivateKey(pemBytes); err != nil {
+
+ private, err := ParsePrivateKey(privateBytes)
+ if err != nil {
panic("Failed to parse private key")
}
+ config.AddHostKey(private)
+
// Once a ServerConfig has been configured, connections can be
// accepted.
listener, err := Listen("tcp", "0.0.0.0:2022", config)
diff --git a/ssh/keys.go b/ssh/keys.go
index c135d3a..9738694 100644
--- a/ssh/keys.go
+++ b/ssh/keys.go
@@ -11,7 +11,12 @@
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
+ "crypto/x509"
"encoding/base64"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "io"
"math/big"
)
@@ -209,12 +214,17 @@
// Verify that sig is a signature on the given data using this
// key. This function will hash the data appropriately first.
Verify(data []byte, sigBlob []byte) bool
-
- // RawKey returns the underlying object, eg. *rsa.PublicKey.
- RawKey() interface{}
}
-// TODO(hanwen): define PrivateKey too.
+// A Signer is can create signatures that verify against a public key.
+type Signer interface {
+ // PublicKey returns an associated PublicKey instance.
+ PublicKey() PublicKey
+
+ // Sign returns raw signature for the given data. This method
+ // will apply the hash specified for the keytype to the data.
+ Sign(rand io.Reader, data []byte) ([]byte, error)
+}
type rsaPublicKey rsa.PublicKey
@@ -223,11 +233,7 @@
}
func (r *rsaPublicKey) PublicKeyAlgo() string {
- return "ssh-rsa"
-}
-
-func (r *rsaPublicKey) RawKey() interface{} {
- return (*rsa.PublicKey)(r)
+ return r.PrivateKeyAlgo()
}
// parseRSA parses an RSA key according to RFC 4253, section 6.6.
@@ -250,7 +256,7 @@
}
ok = true
- return NewRSAPublicKey(key), in, ok
+ return (*rsaPublicKey)(key), in, ok
}
func (r *rsaPublicKey) Marshal() []byte {
@@ -273,8 +279,19 @@
return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig) == nil
}
-func NewRSAPublicKey(k *rsa.PublicKey) PublicKey {
- return (*rsaPublicKey)(k)
+type rsaPrivateKey struct {
+ *rsa.PrivateKey
+}
+
+func (r *rsaPrivateKey) PublicKey() PublicKey {
+ return (*rsaPublicKey)(&r.PrivateKey.PublicKey)
+}
+
+func (r *rsaPrivateKey) Sign(rand io.Reader, data []byte) ([]byte, error) {
+ h := crypto.SHA1.New()
+ h.Write(data)
+ digest := h.Sum(nil)
+ return rsa.SignPKCS1v15(rand, r.PrivateKey, crypto.SHA1, digest)
}
type dsaPublicKey dsa.PublicKey
@@ -282,11 +299,9 @@
func (r *dsaPublicKey) PrivateKeyAlgo() string {
return "ssh-dss"
}
+
func (r *dsaPublicKey) PublicKeyAlgo() string {
- return "ssh-dss"
-}
-func (r *dsaPublicKey) RawKey() interface{} {
- return (*dsa.PublicKey)(r)
+ return r.PrivateKeyAlgo()
}
// parseDSA parses an DSA key according to RFC 4253, section 6.6.
@@ -310,7 +325,7 @@
}
ok = true
- return NewDSAPublicKey(key), in, ok
+ return (*dsaPublicKey)(key), in, ok
}
func (r *dsaPublicKey) Marshal() []byte {
@@ -347,21 +362,33 @@
return dsa.Verify((*dsa.PublicKey)(k), digest, r, s)
}
-func NewDSAPublicKey(k *dsa.PublicKey) PublicKey {
- return (*dsaPublicKey)(k)
+type dsaPrivateKey struct {
+ *dsa.PrivateKey
+}
+
+func (k *dsaPrivateKey) PublicKey() PublicKey {
+ return (*dsaPublicKey)(&k.PrivateKey.PublicKey)
+}
+
+func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) ([]byte, error) {
+ h := crypto.SHA1.New()
+ h.Write(data)
+ digest := h.Sum(nil)
+ r, s, err := dsa.Sign(rand, k.PrivateKey, digest)
+ if err != nil {
+ return nil, err
+ }
+
+ sig := make([]byte, 40)
+ copy(sig[:20], r.Bytes())
+ copy(sig[20:], s.Bytes())
+ return sig, nil
}
type ecdsaPublicKey ecdsa.PublicKey
-func NewECDSAPublicKey(k *ecdsa.PublicKey) PublicKey {
- return (*ecdsaPublicKey)(k)
-}
-func (r *ecdsaPublicKey) RawKey() interface{} {
- return (*ecdsa.PublicKey)(r)
-}
-
func (key *ecdsaPublicKey) PrivateKeyAlgo() string {
- return "ecdh-sha2-" + key.nistID()
+ return "ecdsa-sha2-" + key.nistID()
}
func (key *ecdsaPublicKey) nistID() string {
@@ -376,29 +403,25 @@
panic("ssh: unsupported ecdsa key size")
}
-// RFC 5656, section 6.2.1 (for ECDSA).
-func (key *ecdsaPublicKey) hash() crypto.Hash {
- switch key.Params().BitSize {
- case 256:
+func supportedEllipticCurve(curve elliptic.Curve) bool {
+ return (curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521())
+}
+
+// ecHash returns the hash to match the given elliptic curve, see RFC
+// 5656, section 6.2.1
+func ecHash(curve elliptic.Curve) crypto.Hash {
+ bitSize := curve.Params().BitSize
+ switch {
+ case bitSize <= 256:
return crypto.SHA256
- case 384:
+ case bitSize <= 384:
return crypto.SHA384
- case 521:
- return crypto.SHA512
}
- panic("ssh: unsupported ecdsa key size")
+ return crypto.SHA512
}
func (key *ecdsaPublicKey) PublicKeyAlgo() string {
- switch key.Params().BitSize {
- case 256:
- return KeyAlgoECDSA256
- case 384:
- return KeyAlgoECDSA384
- case 521:
- return KeyAlgoECDSA521
- }
- panic("ssh: unsupported ecdsa key size")
+ return key.PrivateKeyAlgo()
}
// parseECDSA parses an ECDSA key according to RFC 5656, section 3.1.
@@ -432,7 +455,7 @@
ok = false
return
}
- return NewECDSAPublicKey(key), in, ok
+ return (*ecdsaPublicKey)(key), in, ok
}
func (key *ecdsaPublicKey) Marshal() []byte {
@@ -450,7 +473,7 @@
}
func (key *ecdsaPublicKey) Verify(data []byte, sigBlob []byte) bool {
- h := key.hash().New()
+ h := ecHash(key.Curve).New()
h.Write(data)
digest := h.Sum(nil)
@@ -468,3 +491,100 @@
}
return ecdsa.Verify((*ecdsa.PublicKey)(key), digest, r, s)
}
+
+type ecdsaPrivateKey struct {
+ *ecdsa.PrivateKey
+}
+
+func (k *ecdsaPrivateKey) PublicKey() PublicKey {
+ return (*ecdsaPublicKey)(&k.PrivateKey.PublicKey)
+}
+
+func (k *ecdsaPrivateKey) Sign(rand io.Reader, data []byte) ([]byte, error) {
+ h := ecHash(k.PrivateKey.PublicKey.Curve).New()
+ h.Write(data)
+ digest := h.Sum(nil)
+ r, s, err := ecdsa.Sign(rand, k.PrivateKey, digest)
+ if err != nil {
+ return nil, err
+ }
+
+ sig := make([]byte, intLength(r)+intLength(s))
+ rest := marshalInt(sig, r)
+ marshalInt(rest, s)
+ return sig, nil
+}
+
+// NewPrivateKey takes a pointer to rsa, dsa or ecdsa PrivateKey
+// returns a corresponding Signer instance. EC keys should use P256,
+// P384 or P521.
+func NewSignerFromKey(k interface{}) (Signer, error) {
+ var sshKey Signer
+ switch t := k.(type) {
+ case *rsa.PrivateKey:
+ sshKey = &rsaPrivateKey{t}
+ case *dsa.PrivateKey:
+ sshKey = &dsaPrivateKey{t}
+ case *ecdsa.PrivateKey:
+ if !supportedEllipticCurve(t.Curve) {
+ return nil, errors.New("ssh: only P256, P384 and P521 EC keys are supported.")
+ }
+
+ sshKey = &ecdsaPrivateKey{t}
+ default:
+ return nil, fmt.Errorf("ssh: unsupported key type %T", k)
+ }
+ return sshKey, nil
+}
+
+// NewPublicKey takes a pointer to rsa, dsa or ecdsa PublicKey
+// and returns a corresponding ssh PublicKey instance. EC keys should use P256, P384 or P521.
+func NewPublicKey(k interface{}) (PublicKey, error) {
+ var sshKey PublicKey
+ switch t := k.(type) {
+ case *rsa.PublicKey:
+ sshKey = (*rsaPublicKey)(t)
+ case *ecdsa.PublicKey:
+ if !supportedEllipticCurve(t.Curve) {
+ return nil, errors.New("ssh: only P256, P384 and P521 EC keys are supported.")
+ }
+ sshKey = (*ecdsaPublicKey)(t)
+ case *dsa.PublicKey:
+ sshKey = (*dsaPublicKey)(t)
+ default:
+ return nil, fmt.Errorf("ssh: unsupported key type %T", k)
+ }
+ return sshKey, nil
+}
+
+// ParsePublicKey parses a PEM encoded private key. Currently, only
+// PKCS#1, RSA and ECDSA private keys are supported.
+func ParsePrivateKey(pemBytes []byte) (Signer, error) {
+ block, _ := pem.Decode(pemBytes)
+ if block == nil {
+ return nil, errors.New("ssh: no key found")
+ }
+
+ var rawkey interface{}
+ switch block.Type {
+ case "RSA PRIVATE KEY":
+ rsa, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+ if err != nil {
+ return nil, err
+ }
+ rawkey = rsa
+ case "EC PRIVATE KEY":
+ ec, err := x509.ParseECPrivateKey(block.Bytes)
+ if err != nil {
+ return nil, err
+ }
+ rawkey = ec
+
+ // TODO(hanwen): find doc for format and implement PEM parsing
+ // for DSA keys.
+ default:
+ return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type)
+ }
+
+ return NewSignerFromKey(rawkey)
+}
diff --git a/ssh/keys_test.go b/ssh/keys_test.go
index b77cacb..fb3f21e 100644
--- a/ssh/keys_test.go
+++ b/ssh/keys_test.go
@@ -1,60 +1,138 @@
package ssh
import (
- "crypto"
"crypto/dsa"
+ "crypto/ecdsa"
+ "crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"reflect"
+ "strings"
"testing"
)
-func TestRSAMarshal(t *testing.T) {
- k0 := &rsakey.PublicKey
- k1 := NewRSAPublicKey(k0)
- k2, rest, ok := ParsePublicKey(MarshalPublicKey(k1))
- if !ok {
- t.Errorf("could not parse back Blob output")
+var ecdsaKey Signer
+
+func rawKey(pub PublicKey) interface{} {
+ switch k := pub.(type) {
+ case *rsaPublicKey:
+ return (*rsa.PublicKey)(k)
+ case *dsaPublicKey:
+ return (*dsa.PublicKey)(k)
+ case *ecdsaPublicKey:
+ return (*ecdsa.PublicKey)(k)
}
- if len(rest) > 0 {
- t.Errorf("trailing junk in RSA Blob() output")
- }
- if !reflect.DeepEqual(k0, k2.RawKey().(*rsa.PublicKey)) {
- t.Errorf("got %#v in roundtrip, want %#v", k2.RawKey(), k0)
+ panic("unknown key type")
+}
+
+func TestKeyMarshalParse(t *testing.T) {
+ keys := []Signer{rsaKey, dsaKey, ecdsaKey}
+ for _, priv := range keys {
+ pub := priv.PublicKey()
+ roundtrip, rest, ok := ParsePublicKey(MarshalPublicKey(pub))
+ if !ok {
+ t.Errorf("ParsePublicKey(%T) failed", pub)
+ }
+
+ if len(rest) > 0 {
+ t.Errorf("ParsePublicKey(%T): trailing junk", pub)
+ }
+
+ k1 := rawKey(pub)
+ k2 := rawKey(roundtrip)
+
+ if !reflect.DeepEqual(k1, k2) {
+ t.Errorf("got %#v in roundtrip, want %#v", k2, k1)
+ }
}
}
-func TestRSAKeyVerify(t *testing.T) {
- pub := NewRSAPublicKey(&rsakey.PublicKey)
-
- data := []byte("sign me")
- h := crypto.SHA1.New()
- h.Write(data)
- digest := h.Sum(nil)
-
- sig, err := rsa.SignPKCS1v15(rand.Reader, rsakey, crypto.SHA1, digest)
+func TestUnsupportedCurves(t *testing.T) {
+ raw, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
if err != nil {
- t.Fatalf("SignPKCS1v15: %v", err)
+ t.Fatalf("GenerateKey: %v", err)
}
- if !pub.Verify(data, sig) {
- t.Errorf("publicKey.Verify failed")
+ if _, err = NewSignerFromKey(raw); err == nil || !strings.Contains(err.Error(), "only P256") {
+ t.Fatalf("NewPrivateKey should not succeed with P224, got: %v", err)
+ }
+
+ if _, err = NewPublicKey(&raw.PublicKey); err == nil || !strings.Contains(err.Error(), "only P256") {
+ t.Fatalf("NewPublicKey should not succeed with P224, got: %v", err)
}
}
-func TestDSAMarshal(t *testing.T) {
- k0 := &dsakey.PublicKey
- k1 := NewDSAPublicKey(k0)
- k2, rest, ok := ParsePublicKey(MarshalPublicKey(k1))
+func TestNewPublicKey(t *testing.T) {
+ keys := []Signer{rsaKey, dsaKey, ecdsaKey}
+ for _, k := range keys {
+ raw := rawKey(k.PublicKey())
+ pub, err := NewPublicKey(raw)
+ if err != nil {
+ t.Errorf("NewPublicKey(%#v): %v", raw, err)
+ }
+ if !reflect.DeepEqual(k.PublicKey(), pub) {
+ t.Errorf("NewPublicKey(%#v) = %#v, want %#v", raw, pub, k.PublicKey())
+ }
+ }
+}
+
+func TestKeySignVerify(t *testing.T) {
+ keys := []Signer{rsaKey, dsaKey, ecdsaKey}
+ for _, priv := range keys {
+ pub := priv.PublicKey()
+
+ data := []byte("sign me")
+ sig, err := priv.Sign(rand.Reader, data)
+ if err != nil {
+ t.Fatalf("Sign(%T): %v", priv, err)
+ }
+
+ if !pub.Verify(data, sig) {
+ t.Errorf("publicKey.Verify(%T) failed", priv)
+ }
+ }
+}
+
+func TestParseRSAPrivateKey(t *testing.T) {
+ key, err := ParsePrivateKey([]byte(testServerPrivateKey))
+ if err != nil {
+ t.Fatalf("ParsePrivateKey: %v", err)
+ }
+
+ rsa, ok := key.(*rsaPrivateKey)
if !ok {
- t.Errorf("could not parse back Blob output")
+ t.Fatalf("got %T, want *rsa.PrivateKey", rsa)
}
- if len(rest) > 0 {
- t.Errorf("trailing junk in DSA Blob() output")
- }
- if !reflect.DeepEqual(k0, k2.RawKey().(*dsa.PublicKey)) {
- t.Errorf("got %#v in roundtrip, want %#v", k2.RawKey(), k0)
+
+ if err := rsa.Validate(); err != nil {
+ t.Errorf("Validate: %v", err)
}
}
-// TODO(hanwen): test for ECDSA marshal.
+func TestParseECPrivateKey(t *testing.T) {
+ // Taken from the data in test/ .
+ pem := []byte(`-----BEGIN EC PRIVATE KEY-----
+MHcCAQEEINGWx0zo6fhJ/0EAfrPzVFyFC9s18lBt3cRoEDhS3ARooAoGCCqGSM49
+AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+
+6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA==
+-----END EC PRIVATE KEY-----`)
+
+ key, err := ParsePrivateKey(pem)
+ if err != nil {
+ t.Fatalf("ParsePrivateKey: %v", err)
+ }
+
+ ecKey, ok := key.(*ecdsaPrivateKey)
+ if !ok {
+ t.Fatalf("got %T, want *ecdsaPrivateKey", ecKey)
+ }
+
+ if !validateECPublicKey(ecKey.Curve, ecKey.X, ecKey.Y) {
+ t.Fatalf("public key does not validate.")
+ }
+}
+
+func init() {
+ raw, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ ecdsaKey, _ = NewSignerFromKey(raw)
+}
diff --git a/ssh/server.go b/ssh/server.go
index dc5cd9f..ffc35dd 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -10,10 +10,7 @@
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
- "crypto/rsa"
- "crypto/x509"
"encoding/binary"
- "encoding/pem"
"errors"
"io"
"math/big"
@@ -24,11 +21,7 @@
)
type ServerConfig struct {
- rsa *rsa.PrivateKey
-
- // rsaSerialized is the serialized form of the public key that
- // corresponds to the private key held in the rsa field.
- rsaSerialized []byte
+ hostKeys []Signer
// Rand provides the source of entropy for key exchange. If Rand is
// nil, the cryptographic random reader in package crypto/rand will
@@ -69,21 +62,29 @@
return c.Rand
}
+// AddHostKey adds a private key as a host key. If an existing host
+// key exists with the same algorithm, it is overwritten.
+func (s *ServerConfig) AddHostKey(key Signer) {
+ for i, k := range s.hostKeys {
+ if k.PublicKey().PublicKeyAlgo() == key.PublicKey().PublicKeyAlgo() {
+ s.hostKeys[i] = key
+ return
+ }
+ }
+
+ s.hostKeys = append(s.hostKeys, key)
+}
+
// SetRSAPrivateKey sets the private key for a Server. A Server must have a
// private key configured in order to accept connections. The private key must
// be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa"
// typically contains such a key.
func (s *ServerConfig) SetRSAPrivateKey(pemBytes []byte) error {
- block, _ := pem.Decode(pemBytes)
- if block == nil {
- return errors.New("ssh: no key found")
- }
- rsa, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+ priv, err := ParsePrivateKey(pemBytes)
if err != nil {
return err
}
- s.rsa = rsa
- s.rsaSerialized = MarshalPublicKey(NewRSAPublicKey(&rsa.PublicKey))
+ s.AddHostKey(priv)
return nil
}
@@ -141,7 +142,7 @@
// kexECDH performs Elliptic Curve Diffie-Hellman key agreement on a
// ServerConnection, as documented in RFC 5656, section 4.
-func (s *ServerConn) kexECDH(curve elliptic.Curve, magics *handshakeMagics, hostKeyAlgo string) (result *kexResult, err error) {
+func (s *ServerConn) kexECDH(curve elliptic.Curve, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
packet, err := s.readPacket()
if err != nil {
return
@@ -169,7 +170,7 @@
return nil, err
}
- hostKeyBytes := s.config.rsaSerialized
+ hostKeyBytes := MarshalPublicKey(priv.PublicKey())
serializedEphKey := elliptic.Marshal(curve, ephKey.PublicKey.X, ephKey.PublicKey.Y)
@@ -192,7 +193,9 @@
H := h.Sum(nil)
- sig, err := s.serializedHostKeySignature(hostKeyAlgo, H)
+ // H is already a hash, but the hostkey signing will apply its
+ // own key specific hash algorithm.
+ sig, err := signAndMarshal(priv, s.config.rand(), H)
if err != nil {
return nil, err
}
@@ -247,7 +250,7 @@
}
// kexDH performs Diffie-Hellman key agreement on a ServerConnection.
-func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (result *kexResult, err error) {
+func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
packet, err := s.readPacket()
if err != nil {
return
@@ -268,7 +271,7 @@
return nil, err
}
- hostKeyBytes := s.config.rsaSerialized
+ hostKeyBytes := MarshalPublicKey(priv.PublicKey())
h := hashFunc.New()
writeString(h, magics.clientVersion)
@@ -285,7 +288,9 @@
H := h.Sum(nil)
- sig, err := s.serializedHostKeySignature(hostKeyAlgo, H)
+ // H is already a hash, but the hostkey signing will apply its
+ // own key specific hash algorithm.
+ sig, err := signAndMarshal(priv, s.config.rand(), H)
if err != nil {
return nil, err
}
@@ -306,25 +311,15 @@
}, nil
}
-// serializedHostKeySignature signs the hashed data, and serializes
-// the signature according to SSH conventions.
-func (s *ServerConn) serializedHostKeySignature(hostKeyAlgo string, hashed []byte) ([]byte, error) {
- var sig []byte
- switch hostKeyAlgo {
- case hostAlgoRSA:
- hashFunc := crypto.SHA1
- hh := hashFunc.New()
- hh.Write(hashed)
- var err error
- sig, err = rsa.SignPKCS1v15(s.config.rand(), s.config.rsa, hashFunc, hh.Sum(nil))
- if err != nil {
- return nil, err
- }
- default:
- return nil, errors.New("ssh: internal error")
+// signAndMarshal signs the data with the appropriate algorithm,
+// and serializes the result in SSH wire format.
+func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) {
+ sig, err := k.Sign(rand, data)
+ if err != nil {
+ return nil, err
}
- return serializeSignature(hostKeyAlgo, sig), nil
+ return serializeSignature(k.PublicKey().PrivateKeyAlgo(), sig), nil
}
// serverVersion is the fixed identification string that Server will use.
@@ -374,7 +369,6 @@
func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexInitPacket []byte) (err error) {
serverKexInit := kexInitMsg{
KexAlgos: s.config.Crypto.kexes(),
- ServerHostKeyAlgos: supportedHostKeyAlgos,
CiphersClientServer: s.config.Crypto.ciphers(),
CiphersServerClient: s.config.Crypto.ciphers(),
MACsClientServer: s.config.Crypto.macs(),
@@ -382,8 +376,12 @@
CompressionClientServer: supportedCompressions,
CompressionServerClient: supportedCompressions,
}
- serverKexInitPacket := marshal(msgKexInit, serverKexInit)
+ for _, k := range s.config.hostKeys {
+ serverKexInit.ServerHostKeyAlgos = append(
+ serverKexInit.ServerHostKeyAlgos, k.PublicKey().PublicKeyAlgo())
+ }
+ serverKexInitPacket := marshal(msgKexInit, serverKexInit)
if err = s.writePacket(serverKexInitPacket); err != nil {
return
}
@@ -410,6 +408,13 @@
}
}
+ var hostKey Signer
+ for _, k := range s.config.hostKeys {
+ if hostKeyAlgo == k.PublicKey().PublicKeyAlgo() {
+ hostKey = k
+ }
+ }
+
var magics handshakeMagics
magics.serverVersion = serverVersion[:len(serverVersion)-2]
magics.clientVersion = s.ClientVersion
@@ -419,17 +424,17 @@
var result *kexResult
switch kexAlgo {
case kexAlgoECDH256:
- result, err = s.kexECDH(elliptic.P256(), &magics, hostKeyAlgo)
+ result, err = s.kexECDH(elliptic.P256(), &magics, hostKey)
case kexAlgoECDH384:
- result, err = s.kexECDH(elliptic.P384(), &magics, hostKeyAlgo)
+ result, err = s.kexECDH(elliptic.P384(), &magics, hostKey)
case kexAlgoECDH521:
- result, err = s.kexECDH(elliptic.P521(), &magics, hostKeyAlgo)
+ result, err = s.kexECDH(elliptic.P521(), &magics, hostKey)
case kexAlgoDH14SHA1:
dhGroup14Once.Do(initDHGroup14)
- result, err = s.kexDH(dhGroup14, crypto.SHA1, &magics, hostKeyAlgo)
+ result, err = s.kexDH(dhGroup14, crypto.SHA1, &magics, hostKey)
case kexAlgoDH1SHA1:
dhGroup1Once.Do(initDHGroup1)
- result, err = s.kexDH(dhGroup1, crypto.SHA1, &magics, hostKeyAlgo)
+ result, err = s.kexDH(dhGroup1, crypto.SHA1, &magics, hostKey)
default:
err = errors.New("ssh: unexpected key exchange algorithm " + kexAlgo)
}
diff --git a/ssh/test/keys_test.go b/ssh/test/keys_test.go
index 363872c..b116422 100644
--- a/ssh/test/keys_test.go
+++ b/ssh/test/keys_test.go
@@ -1,8 +1,6 @@
package test
import (
- "crypto/x509"
- "encoding/pem"
"reflect"
"strings"
"testing"
@@ -171,16 +169,12 @@
}
func getTestPublicKey(t *testing.T) ssh.PublicKey {
- block, _ := pem.Decode([]byte(testClientPrivateKey))
- if block == nil {
- t.Fatalf("pem.Decode: %v", testClientPrivateKey)
- }
- priv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+ priv, err := ssh.ParsePrivateKey([]byte(testClientPrivateKey))
if err != nil {
- t.Fatalf("x509.ParsePKCS1PrivateKey: %v", err)
+ t.Fatalf("ParsePrivateKey: %v", err)
}
- return ssh.NewRSAPublicKey(&priv.PublicKey)
+ return priv.PublicKey()
}
func TestAuth(t *testing.T) {
diff --git a/ssh/test/test_unix_test.go b/ssh/test/test_unix_test.go
index 7ab5e22..6311876 100644
--- a/ssh/test/test_unix_test.go
+++ b/ssh/test/test_unix_test.go
@@ -10,12 +10,7 @@
import (
"bytes"
- "crypto"
- "crypto/dsa"
- "crypto/rsa"
- "crypto/x509"
- "encoding/pem"
- "errors"
+ "fmt"
"io"
"io/ioutil"
"log"
@@ -53,25 +48,23 @@
`
var (
- configTmpl template.Template
- rsakey *rsa.PrivateKey
- serializedHostKey []byte
+ configTmpl template.Template
+ privateKey ssh.Signer
+ hostKey ssh.Signer
)
func init() {
template.Must(configTmpl.Parse(sshd_config))
- block, _ := pem.Decode([]byte(testClientPrivateKey))
- rsakey, _ = x509.ParsePKCS1PrivateKey(block.Bytes)
- block, _ = pem.Decode([]byte(keys["ssh_host_rsa_key"]))
- if block == nil {
- panic("pem.Decode ssh_host_rsa_key")
- }
- priv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+ var err error
+ hostKey, err = ssh.ParsePrivateKey([]byte(keys["ssh_host_rsa_key"]))
if err != nil {
- panic("ParsePKCS1PrivateKey: " + err.Error())
+ panic("ParsePrivateKey: " + err.Error())
}
- serializedHostKey = ssh.MarshalPublicKey(ssh.NewRSAPublicKey(&priv.PublicKey))
+ privateKey, err = ssh.ParsePrivateKey([]byte(testClientPrivateKey))
+ if err != nil {
+ panic("ParsePrivateKey: " + err.Error())
+ }
}
type server struct {
@@ -106,26 +99,26 @@
keys map[string][]byte
}
-func (k *storedHostKey) Add(algo string, public []byte) {
+func (k *storedHostKey) Add(key ssh.PublicKey) {
if k.keys == nil {
k.keys = map[string][]byte{}
}
- k.keys[algo] = append([]byte(nil), public...)
+ k.keys[key.PublicKeyAlgo()] = append([]byte(nil), ssh.MarshalPublicKey(key)...)
}
func (k *storedHostKey) Check(addr string, remote net.Addr, algo string, key []byte) error {
if k.keys == nil || bytes.Compare(key, k.keys[algo]) != 0 {
- return errors.New("host key mismatch")
+ return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo])
}
return nil
}
func clientConfig() *ssh.ClientConfig {
keyChecker := storedHostKey{}
- keyChecker.Add("ssh-rsa", serializedHostKey)
+ keyChecker.Add(hostKey.PublicKey())
kc := new(keychain)
- kc.keys = append(kc.keys, rsakey)
+ kc.keys = append(kc.keys, privateKey)
config := &ssh.ClientConfig{
User: username(),
Auth: []ssh.ClientAuth{
@@ -261,34 +254,20 @@
}
}
-// keychain implements the ClientKeyring interface
+// keychain implements the ClientKeyring interface.
type keychain struct {
- keys []interface{}
+ keys []ssh.Signer
}
func (k *keychain) Key(i int) (ssh.PublicKey, error) {
if i < 0 || i >= len(k.keys) {
return nil, nil
}
- switch key := k.keys[i].(type) {
- case *rsa.PrivateKey:
- return ssh.NewRSAPublicKey(&key.PublicKey), nil
- case *dsa.PrivateKey:
- return ssh.NewDSAPublicKey(&key.PublicKey), nil
- }
- panic("unknown key type")
+ return k.keys[i].PublicKey(), nil
}
func (k *keychain) Sign(i int, rand io.Reader, data []byte) (sig []byte, err error) {
- hashFunc := crypto.SHA1
- h := hashFunc.New()
- h.Write(data)
- digest := h.Sum(nil)
- switch key := k.keys[i].(type) {
- case *rsa.PrivateKey:
- return rsa.SignPKCS1v15(rand, key, hashFunc, digest)
- }
- return nil, errors.New("ssh: unknown key type")
+ return k.keys[i].Sign(rand, data)
}
func (k *keychain) loadPEM(file string) error {
@@ -296,14 +275,10 @@
if err != nil {
return err
}
- block, _ := pem.Decode(buf)
- if block == nil {
- return errors.New("ssh: no key found")
- }
- r, err := x509.ParsePKCS1PrivateKey(block.Bytes)
+ key, err := ssh.ParsePrivateKey(buf)
if err != nil {
return err
}
- k.keys = append(k.keys, r)
+ k.keys = append(k.keys, key)
return nil
}