go.crypto/ssh: Add certificate verification; stop skipping ECDSA keys and certificates when parsing authorized keys
R=agl, hanwen, jpsugar, dave
CC=golang-dev
https://golang.org/cl/14540051
diff --git a/ssh/certs.go b/ssh/certs.go
index 0772289..1023d8a 100644
--- a/ssh/certs.go
+++ b/ssh/certs.go
@@ -52,6 +52,13 @@
Signature *signature
}
+// validateOpenSSHCertV01Signature reports whether cert.Signature.Blob is a
+// valid signature, made with cert.SignatureKey, over the bytes returned by
+// cert.BytesForSigning: the certificate encoding starting at the algorithm
+// string and going up to and including the SignatureKey.
+//
+// A certificate that lacks a Key, SignatureKey or Signature cannot be
+// verified; it is reported as invalid instead of causing a nil-pointer panic.
+func validateOpenSSHCertV01Signature(cert *OpenSSHCertV01) bool {
+	if cert == nil || cert.Key == nil || cert.SignatureKey == nil || cert.Signature == nil {
+		return false
+	}
+	return cert.SignatureKey.Verify(cert.BytesForSigning(), cert.Signature.Blob)
+}
+
var certAlgoNames = map[string]string{
KeyAlgoRSA: CertAlgoRSAv01,
KeyAlgoDSA: CertAlgoDSAv01,
@@ -71,6 +78,66 @@
panic("unknown cert algorithm")
}
+// marshal encodes cert in the OpenSSH v01 certificate layout. Fields are
+// written in this fixed order: the certificate algorithm name (only if
+// includeAlgo), nonce, the certified key as returned by Key.Marshal, serial,
+// type, key id, valid principals, valid-after, valid-before, critical
+// options, extensions, reserved, the marshaled signature key, and the
+// signature (only if includeSig).
+//
+// The output buffer is sized exactly in a first pass. marshal panics if the
+// second (writing) pass does not fill it, which would mean the two passes
+// have drifted apart.
+func (cert *OpenSSHCertV01) marshal(includeAlgo, includeSig bool) []byte {
+	algoName := cert.PublicKeyAlgo()
+	pubKey := cert.Key.Marshal()
+	sigKey := MarshalPublicKey(cert.SignatureKey)
+
+	// Pass 1: compute the encoded size. Every term here must mirror a write
+	// in pass 2 below.
+	var length int
+	if includeAlgo {
+		length += stringLength(len(algoName))
+	}
+	length += stringLength(len(cert.Nonce))
+	// The certified key is copied in raw, without a length prefix.
+	length += len(pubKey)
+	length += 8 // Length of Serial
+	length += 4 // Length of Type
+	length += stringLength(len(cert.KeyId))
+	length += lengthPrefixedNameListLength(cert.ValidPrincipals)
+	length += 8 // Length of ValidAfter
+	length += 8 // Length of ValidBefore
+	length += tupleListLength(cert.CriticalOptions)
+	length += tupleListLength(cert.Extensions)
+	length += stringLength(len(cert.Reserved))
+	length += stringLength(len(sigKey))
+	if includeSig {
+		length += signatureLength(cert.Signature)
+	}
+
+	// Pass 2: write each field into the preallocated buffer; r is the
+	// still-unwritten tail of ret.
+	ret := make([]byte, length)
+	r := ret
+	if includeAlgo {
+		r = marshalString(r, []byte(algoName))
+	}
+	r = marshalString(r, cert.Nonce)
+	copy(r, pubKey)
+	r = r[len(pubKey):]
+	r = marshalUint64(r, cert.Serial)
+	r = marshalUint32(r, cert.Type)
+	r = marshalString(r, []byte(cert.KeyId))
+	r = marshalLengthPrefixedNameList(r, cert.ValidPrincipals)
+	// Validity bounds are encoded as Unix seconds.
+	r = marshalUint64(r, uint64(cert.ValidAfter.Unix()))
+	r = marshalUint64(r, uint64(cert.ValidBefore.Unix()))
+	r = marshalTupleList(r, cert.CriticalOptions)
+	r = marshalTupleList(r, cert.Extensions)
+	r = marshalString(r, cert.Reserved)
+	r = marshalString(r, sigKey)
+	if includeSig {
+		r = marshalSignature(r, cert.Signature)
+	}
+	// Any leftover bytes mean pass 1 and pass 2 disagree.
+	if len(r) > 0 {
+		panic("ssh: internal error, marshaling certificate did not fill the entire buffer")
+	}
+	return ret
+}
+
+func (cert *OpenSSHCertV01) BytesForSigning() []byte {
+ return cert.marshal(true, false)
+}
+
+func (cert *OpenSSHCertV01) Marshal() []byte {
+ return cert.marshal(false, true)
+}
+
func (c *OpenSSHCertV01) PublicKeyAlgo() string {
algo, ok := certAlgoNames[c.Key.PublicKeyAlgo()]
if !ok {
@@ -110,7 +177,7 @@
return
}
- if cert.Type, in, ok = parseUint32(in); !ok || cert.Type != UserCert && cert.Type != HostCert {
+ if cert.Type, in, ok = parseUint32(in); !ok {
return
}
@@ -164,45 +231,6 @@
return cert, in, ok
}
-func (cert *OpenSSHCertV01) Marshal() []byte {
- pubKey := cert.Key.Marshal()
- sigKey := MarshalPublicKey(cert.SignatureKey)
-
- length := stringLength(len(cert.Nonce))
- length += len(pubKey)
- length += 8 // Length of Serial
- length += 4 // Length of Type
- length += stringLength(len(cert.KeyId))
- length += lengthPrefixedNameListLength(cert.ValidPrincipals)
- length += 8 // Length of ValidAfter
- length += 8 // Length of ValidBefore
- length += tupleListLength(cert.CriticalOptions)
- length += tupleListLength(cert.Extensions)
- length += stringLength(len(cert.Reserved))
- length += stringLength(len(sigKey))
- length += signatureLength(cert.Signature)
-
- ret := make([]byte, length)
- r := marshalString(ret, cert.Nonce)
- copy(r, pubKey)
- r = r[len(pubKey):]
- r = marshalUint64(r, cert.Serial)
- r = marshalUint32(r, cert.Type)
- r = marshalString(r, []byte(cert.KeyId))
- r = marshalLengthPrefixedNameList(r, cert.ValidPrincipals)
- r = marshalUint64(r, uint64(cert.ValidAfter.Unix()))
- r = marshalUint64(r, uint64(cert.ValidBefore.Unix()))
- r = marshalTupleList(r, cert.CriticalOptions)
- r = marshalTupleList(r, cert.Extensions)
- r = marshalString(r, cert.Reserved)
- r = marshalString(r, sigKey)
- r = marshalSignature(r, cert.Signature)
- if len(r) > 0 {
- panic("internal error")
- }
- return ret
-}
-
func lengthPrefixedNameListLength(namelist []string) int {
length := 4 // length prefix for list
for _, name := range namelist {
@@ -320,6 +348,9 @@
return
}
- // TODO(hanwen): this is a bug; 'rest' gets swallowed.
- return parseSignatureBody(sigBytes)
+ out, sigBytes, ok = parseSignatureBody(sigBytes)
+ if !ok || len(sigBytes) > 0 {
+ return nil, nil, false
+ }
+ return
}
diff --git a/ssh/certs_test.go b/ssh/certs_test.go
new file mode 100644
index 0000000..3cec28e
--- /dev/null
+++ b/ssh/certs_test.go
@@ -0,0 +1,55 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "bytes"
+ "testing"
+)
+
+// exampleSSHCert is an ssh-rsa-cert-v01@openssh.com certificate in
+// authorized_keys line format (no trailing newline), generated by
+// ssh-keygen 6.0p1 Debian-4 with:
+//
+//	% ssh-keygen -s ca-key -I test user-key
+var exampleSSHCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgb1srW/W3ZDjYAO45xLYAwzHBDLsJ4Ux6ICFIkTjb1LEAAAADAQABAAAAYQCkoR51poH0wE8w72cqSB8Sszx+vAhzcMdCO0wqHTj7UNENHWEXGrU0E0UQekD7U+yhkhtoyjbPOVIP7hNa6aRk/ezdh/iUnCIt4Jt1v3Z1h1P+hA4QuYFMHNB+rmjPwAcAAAAAAAAAAAAAAAEAAAAEdGVzdAAAAAAAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAHcAAAAHc3NoLXJzYQAAAAMBAAEAAABhANFS2kaktpSGc+CcmEKPyw9mJC4nZKxHKTgLVZeaGbFZOvJTNzBspQHdy7Q1uKSfktxpgjZnksiu/tFF9ngyY2KFoc+U88ya95IZUycBGCUbBQ8+bhDtw/icdDGQD5WnUwAAAG8AAAAHc3NoLXJzYQAAAGC8Y9Z2LQKhIhxf52773XaWrXdxP0t3GBVo4A10vUWiYoAGepr6rQIoGGXFxT4B9Gp+nEBJjOwKDXPrAevow0T9ca8gZN+0ykbhSrXLE5Ao48rqr3zP4O1/9P7e6gp0gw8=`
+
+// TestParseCert checks that an ssh-keygen-generated certificate parses from
+// authorized_keys format as an *OpenSSHCertV01 and marshals back to the
+// identical line.
+func TestParseCert(t *testing.T) {
+	authKeyBytes := []byte(exampleSSHCert)
+
+	key, _, _, rest, ok := ParseAuthorizedKey(authKeyBytes)
+	if !ok {
+		t.Fatal("could not parse certificate")
+	}
+	if len(rest) > 0 {
+		t.Errorf("rest: got %q, want empty", rest)
+	}
+
+	if _, ok = key.(*OpenSSHCertV01); !ok {
+		t.Fatalf("got %#v, want *OpenSSHCertV01", key)
+	}
+
+	// MarshalAuthorizedKey terminates its output with a newline, which the
+	// original line lacks. Check for it explicitly rather than blindly
+	// dropping the last byte: that would panic on empty output and turn a
+	// missing newline into a confusing one-byte mismatch.
+	marshaled := MarshalAuthorizedKey(key)
+	newline := []byte("\n")
+	if !bytes.HasSuffix(marshaled, newline) {
+		t.Errorf("marshaled certificate %q lacks a trailing newline", marshaled)
+	}
+	marshaled = bytes.TrimSuffix(marshaled, newline)
+	if !bytes.Equal(authKeyBytes, marshaled) {
+		t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes)
+	}
+}
+
+// TestVerifyCert checks that validateOpenSSHCertV01Signature accepts a
+// genuine certificate and rejects certificates whose signature does not
+// cover their current contents.
+func TestVerifyCert(t *testing.T) {
+	key, _, _, _, ok := ParseAuthorizedKey([]byte(exampleSSHCert))
+	if !ok {
+		t.Fatal("could not parse certificate")
+	}
+	validCert, ok := key.(*OpenSSHCertV01)
+	if !ok {
+		t.Fatalf("got %#v, want *OpenSSHCertV01", key)
+	}
+	if !validateOpenSSHCertV01Signature(validCert) {
+		t.Error("valid certificate failed signature validation")
+	}
+
+	// KeyId is part of the signed bytes, so changing it must invalidate the
+	// existing signature.
+	tampered := *validCert
+	tampered.KeyId = tampered.KeyId + "-tampered"
+	if validateOpenSSHCertV01Signature(&tampered) {
+		t.Error("certificate with modified KeyId passed signature validation")
+	}
+
+	invalidCert := &OpenSSHCertV01{
+		Key:          rsaKey.PublicKey(),
+		SignatureKey: ecdsaKey.PublicKey(),
+		Signature:    &signature{},
+	}
+	if validateOpenSSHCertV01Signature(invalidCert) {
+		t.Error("certificate with empty signature passed signature validation")
+	}
+}
diff --git a/ssh/keys.go b/ssh/keys.go
index fa1e236..b41fefc 100644
--- a/ssh/keys.go
+++ b/ssh/keys.go
@@ -102,22 +102,8 @@
continue
}
- field := string(in[:i])
- switch field {
- case KeyAlgoRSA, KeyAlgoDSA:
- out, comment, ok = parseAuthorizedKey(in[i:])
- if ok {
- return
- }
- case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521:
- // We don't support these keys.
- in = rest
- continue
- case CertAlgoRSAv01, CertAlgoDSAv01,
- CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01:
- // We don't support these certificates.
- in = rest
- continue
+ if out, comment, ok = parseAuthorizedKey(in[i:]); ok {
+ return
}
// No key type recognised. Maybe there's an options field at
@@ -157,14 +143,9 @@
continue
}
- field = string(in[:i])
- switch field {
- case KeyAlgoRSA, KeyAlgoDSA:
- out, comment, ok = parseAuthorizedKey(in[i:])
- if ok {
- options = candidateOptions
- return
- }
+ if out, comment, ok = parseAuthorizedKey(in[i:]); ok {
+ options = candidateOptions
+ return
}
in = rest
diff --git a/ssh/keys_test.go b/ssh/keys_test.go
index f05843f..99eac94 100644
--- a/ssh/keys_test.go
+++ b/ssh/keys_test.go
@@ -1,19 +1,66 @@
package ssh
import (
- "bytes"
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
- "encoding/base64"
"reflect"
"strings"
"testing"
+ "time"
)
-var ecdsaKey Signer
+// Test signers, populated by init in this file.
+var (
+	// ecdsaKey, ecdsa384Key and ecdsa521Key are freshly generated ECDSA
+	// signers on the P-256, P-384 and P-521 curves respectively.
+	ecdsaKey    Signer
+	ecdsa384Key Signer
+	ecdsa521Key Signer
+	// testCertKey signs with ecdsaKey's private key but reports an
+	// rsaKey-signed *OpenSSHCertV01 for ecdsaKey as its public key.
+	testCertKey Signer
+)
+
+// testSigner wraps a Signer and, when pub is set, reports pub as its public
+// key instead of the wrapped signer's own. Signing is still done by the
+// embedded Signer, which lets tests pair a certificate with its private key.
+type testSigner struct {
+	Signer
+	pub PublicKey
+}
+
+// PublicKey returns the override key when one was supplied, and otherwise
+// falls back to the embedded Signer's public key.
+func (ts *testSigner) PublicKey() PublicKey {
+	if ts.pub == nil {
+		return ts.Signer.PublicKey()
+	}
+	return ts.pub
+}
+
+// init generates the ECDSA test signers and builds testCertKey: a signer
+// backed by ecdsaKey whose public key is a certificate for
+// ecdsaKey.PublicKey(), signed by rsaKey and valid for one hour from now.
+// Any setup failure panics immediately instead of leaving a nil Signer that
+// would crash some unrelated test later.
+//
+// NOTE(review): rsaKey is not declared in this file. It must already be
+// initialized when this init runs (init functions run in the order files
+// are presented to the compiler) — confirm where and how it is set.
+func init() {
+	newECDSASigner := func(curve elliptic.Curve) Signer {
+		raw, err := ecdsa.GenerateKey(curve, rand.Reader)
+		if err != nil {
+			panic(err)
+		}
+		signer, err := NewSignerFromKey(raw)
+		if err != nil {
+			panic(err)
+		}
+		return signer
+	}
+	ecdsaKey = newECDSASigner(elliptic.P256())
+	ecdsa384Key = newECDSASigner(elliptic.P384())
+	ecdsa521Key = newECDSASigner(elliptic.P521())
+
+	// Read the clock once so ValidAfter and ValidBefore are exactly one hour
+	// apart, even if the call happens to straddle a second boundary.
+	now := time.Now().Truncate(time.Second)
+
+	// Create a cert and sign it for use in tests.
+	testCert := &OpenSSHCertV01{
+		Nonce:           []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
+		Key:             ecdsaKey.PublicKey(),
+		ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
+		ValidAfter:      now,
+		ValidBefore:     now.Add(time.Hour),
+		Reserved:        []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
+		SignatureKey:    rsaKey.PublicKey(),
+	}
+	sigBytes, err := rsaKey.Sign(rand.Reader, testCert.BytesForSigning())
+	if err != nil {
+		panic(err)
+	}
+	testCert.Signature = &signature{
+		Format: testCert.SignatureKey.PublicKeyAlgo(),
+		Blob:   sigBytes,
+	}
+	testCertKey = &testSigner{
+		Signer: ecdsaKey,
+		pub:    testCert,
+	}
+}
func rawKey(pub PublicKey) interface{} {
switch k := pub.(type) {
@@ -23,12 +70,14 @@
return (*dsa.PublicKey)(k)
case *ecdsaPublicKey:
return (*ecdsa.PublicKey)(k)
+ case *OpenSSHCertV01:
+ return k
}
panic("unknown key type")
}
func TestKeyMarshalParse(t *testing.T) {
- keys := []Signer{rsaKey, dsaKey, ecdsaKey}
+ keys := []Signer{rsaKey, dsaKey, ecdsaKey, ecdsa384Key, ecdsa521Key, testCertKey}
for _, priv := range keys {
pub := priv.PublicKey()
roundtrip, rest, ok := ParsePublicKey(MarshalPublicKey(pub))
@@ -79,7 +128,7 @@
}
func TestKeySignVerify(t *testing.T) {
- keys := []Signer{rsaKey, dsaKey, ecdsaKey}
+ keys := []Signer{rsaKey, dsaKey, ecdsaKey, testCertKey}
for _, priv := range keys {
pub := priv.PublicKey()
@@ -164,35 +213,3 @@
t.Error("Verify failed.")
}
}
-
-func TestParseCert(t *testing.T) {
- // Cert generated by ssh-keygen 6.0p1 Debian-4.
- // % ssh-keygen -s ca-key -I test user-key
- b64data := "AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgb1srW/W3ZDjYAO45xLYAwzHBDLsJ4Ux6ICFIkTjb1LEAAAADAQABAAAAYQCkoR51poH0wE8w72cqSB8Sszx+vAhzcMdCO0wqHTj7UNENHWEXGrU0E0UQekD7U+yhkhtoyjbPOVIP7hNa6aRk/ezdh/iUnCIt4Jt1v3Z1h1P+hA4QuYFMHNB+rmjPwAcAAAAAAAAAAAAAAAEAAAAEdGVzdAAAAAAAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAHcAAAAHc3NoLXJzYQAAAAMBAAEAAABhANFS2kaktpSGc+CcmEKPyw9mJC4nZKxHKTgLVZeaGbFZOvJTNzBspQHdy7Q1uKSfktxpgjZnksiu/tFF9ngyY2KFoc+U88ya95IZUycBGCUbBQ8+bhDtw/icdDGQD5WnUwAAAG8AAAAHc3NoLXJzYQAAAGC8Y9Z2LQKhIhxf52773XaWrXdxP0t3GBVo4A10vUWiYoAGepr6rQIoGGXFxT4B9Gp+nEBJjOwKDXPrAevow0T9ca8gZN+0ykbhSrXLE5Ao48rqr3zP4O1/9P7e6gp0gw8="
-
- data, err := base64.StdEncoding.DecodeString(b64data)
- if err != nil {
- t.Fatal("base64.StdEncoding.DecodeString: ", err)
- }
- key, rest, ok := ParsePublicKey(data)
- if !ok {
- t.Fatalf("could not parse certificate")
- }
- if len(rest) > 0 {
- t.Errorf("rest: got %q, want empty", rest)
- }
- _, ok = key.(*OpenSSHCertV01)
- if !ok {
- t.Fatalf("got %#v, want *OpenSSHCertV01", key)
- }
-
- marshaled := MarshalPublicKey(key)
- if !bytes.Equal(data, marshaled) {
- t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, data)
- }
-}
-
-func init() {
- raw, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
- ecdsaKey, _ = NewSignerFromKey(raw)
-}