go.crypto/ssh: add client support for OpenSSH certificates
Refactor key parsing, marshaling, and serialization to be a bit more flexible
R=agl, dave, djm
CC=golang-dev
https://golang.org/cl/5650067
diff --git a/ssh/certs.go b/ssh/certs.go
new file mode 100644
index 0000000..59430ea
--- /dev/null
+++ b/ssh/certs.go
@@ -0,0 +1,306 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+// References
+// [PROTOCOL.certkeys]: http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys
+
+import (
+ "crypto/dsa"
+ "crypto/rsa"
+ "time"
+)
+
+// String constants in [PROTOCOL.certkeys] for certificate algorithm names.
+const (
+ hostAlgoRSACertV01 = "ssh-rsa-cert-v01@openssh.com"
+ hostAlgoDSACertV01 = "ssh-dss-cert-v01@openssh.com"
+ hostAlgoECDSA256CertV01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com"
+ hostAlgoECDSA384CertV01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com"
+ hostAlgoECDSA521CertV01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com"
+)
+
+// Certificate types are used to specify whether a certificate is for identification
+// of a user or a host. Current identities are defined in [PROTOCOL.certkeys].
+const (
+ UserCert = 1
+ HostCert = 2
+)
+
+type signature struct {
+ Format string
+ Blob []byte
+}
+
+type tuple struct {
+ Name string
+ Data string
+}
+
+// An OpenSSHCertV01 represents an OpenSSH certificate as defined in
+// [PROTOCOL.certkeys] rev 1.8. Supported formats include
+// ssh-rsa-cert-v01@openssh.com and ssh-dss-cert-v01@openssh.com.
+type OpenSSHCertV01 struct {
+ Nonce []byte
+ Key interface{} // rsa or dsa *PublicKey
+ Serial uint64
+ Type uint32
+ KeyId string
+ ValidPrincipals []string
+ ValidAfter, ValidBefore time.Time
+ CriticalOptions []tuple
+ Extensions []tuple
+ Reserved []byte
+ SignatureKey interface{} // rsa, dsa, or ecdsa *PublicKey
+ Signature *signature
+}
+
+func parseOpenSSHCertV01(in []byte, algo string) (out *OpenSSHCertV01, rest []byte, ok bool) {
+ cert := new(OpenSSHCertV01)
+
+ if cert.Nonce, in, ok = parseString(in); !ok {
+ return
+ }
+
+ switch algo {
+ case hostAlgoRSACertV01:
+ var rsaPubKey *rsa.PublicKey
+ if rsaPubKey, in, ok = parseRSA(in); !ok {
+ return
+ }
+ cert.Key = rsaPubKey
+ case hostAlgoDSACertV01:
+ var dsaPubKey *dsa.PublicKey
+ if dsaPubKey, in, ok = parseDSA(in); !ok {
+ return
+ }
+ cert.Key = dsaPubKey
+ default:
+ return
+ }
+
+ if cert.Serial, in, ok = parseUint64(in); !ok {
+ return
+ }
+
+ if cert.Type, in, ok = parseUint32(in); !ok || cert.Type != UserCert && cert.Type != HostCert {
+ return
+ }
+
+ keyId, in, ok := parseString(in)
+ if !ok {
+ return
+ }
+ cert.KeyId = string(keyId)
+
+ if cert.ValidPrincipals, in, ok = parseLengthPrefixedNameList(in); !ok {
+ return
+ }
+
+ va, in, ok := parseUint64(in)
+ if !ok {
+ return
+ }
+ cert.ValidAfter = time.Unix(int64(va), 0)
+
+ vb, in, ok := parseUint64(in)
+ if !ok {
+ return
+ }
+ cert.ValidBefore = time.Unix(int64(vb), 0)
+
+ if cert.CriticalOptions, in, ok = parseTupleList(in); !ok {
+ return
+ }
+
+ if cert.Extensions, in, ok = parseTupleList(in); !ok {
+ return
+ }
+
+ if cert.Reserved, in, ok = parseString(in); !ok {
+ return
+ }
+
+ sigKey, in, ok := parseString(in)
+ if !ok {
+ return
+ }
+ if cert.SignatureKey, _, ok = parsePubKey(sigKey); !ok {
+ return
+ }
+
+ if cert.Signature, in, ok = parseSignature(in); !ok {
+ return
+ }
+
+ ok = true
+ return cert, in, ok
+}
+
+func marshalOpenSSHCertV01(cert *OpenSSHCertV01) []byte {
+ var pubKey []byte
+ switch cert.Key.(type) {
+ case *rsa.PublicKey:
+ k := cert.Key.(*rsa.PublicKey)
+ pubKey = marshalPubRSA(k)
+ case *dsa.PublicKey:
+ k := cert.Key.(*dsa.PublicKey)
+ pubKey = marshalPubDSA(k)
+ default:
+ panic("ssh: unknown public key type in cert")
+ }
+
+ sigKey := serializePublickey(cert.SignatureKey)
+
+ length := stringLength(cert.Nonce)
+ length += len(pubKey)
+ length += 8 // Length of Serial
+ length += 4 // Length of Type
+ length += stringLength([]byte(cert.KeyId))
+ length += lengthPrefixedNameListLength(cert.ValidPrincipals)
+ length += 8 // Length of ValidAfter
+ length += 8 // Length of ValidBefore
+ length += tupleListLength(cert.CriticalOptions)
+ length += tupleListLength(cert.Extensions)
+ length += stringLength(cert.Reserved)
+ length += stringLength(sigKey)
+ length += signatureLength(cert.Signature)
+
+ ret := make([]byte, length)
+ r := marshalString(ret, cert.Nonce)
+ copy(r, pubKey)
+ r = r[len(pubKey):]
+ r = marshalUint64(r, cert.Serial)
+ r = marshalUint32(r, cert.Type)
+ r = marshalString(r, []byte(cert.KeyId))
+ r = marshalLengthPrefixedNameList(r, cert.ValidPrincipals)
+ r = marshalUint64(r, uint64(cert.ValidAfter.Unix()))
+ r = marshalUint64(r, uint64(cert.ValidBefore.Unix()))
+ r = marshalTupleList(r, cert.CriticalOptions)
+ r = marshalTupleList(r, cert.Extensions)
+ r = marshalString(r, cert.Reserved)
+ r = marshalString(r, sigKey)
+ r = marshalSignature(r, cert.Signature)
+ if len(r) > 0 {
+ panic("internal error")
+ }
+ return ret
+}
+
+func lengthPrefixedNameListLength(namelist []string) int {
+ length := 4 // length prefix for list
+ for _, name := range namelist {
+ length += 4 // length prefix for name
+ length += len(name)
+ }
+ return length
+}
+
+func marshalLengthPrefixedNameList(to []byte, namelist []string) []byte {
+ length := uint32(lengthPrefixedNameListLength(namelist) - 4)
+ to = marshalUint32(to, length)
+ for _, name := range namelist {
+ to = marshalString(to, []byte(name))
+ }
+ return to
+}
+
+func parseLengthPrefixedNameList(in []byte) (out []string, rest []byte, ok bool) {
+ list, rest, ok := parseString(in)
+ if !ok {
+ return
+ }
+
+ for len(list) > 0 {
+ var next []byte
+ var ok bool
+ next, list, ok = parseString(list)
+ if !ok {
+ return nil, nil, false
+ }
+ out = append(out, string(next))
+ }
+ ok = true
+ return
+}
+
+func tupleListLength(tupleList []tuple) int {
+ length := 4 // length prefix for list
+ for _, t := range tupleList {
+ length += 4 // length prefix for t.Name
+ length += len(t.Name)
+ length += 4 // length prefix for t.Data
+ length += len(t.Data)
+ }
+ return length
+}
+
+func marshalTupleList(to []byte, tuplelist []tuple) []byte {
+ length := uint32(tupleListLength(tuplelist) - 4)
+ to = marshalUint32(to, length)
+ for _, t := range tuplelist {
+ to = marshalString(to, []byte(t.Name))
+ to = marshalString(to, []byte(t.Data))
+ }
+ return to
+}
+
+func parseTupleList(in []byte) (out []tuple, rest []byte, ok bool) {
+ list, rest, ok := parseString(in)
+ if !ok {
+ return
+ }
+
+ for len(list) > 0 {
+ var name, data []byte
+ var ok bool
+ name, list, ok = parseString(list)
+ if !ok {
+ return nil, nil, false
+ }
+ data, list, ok = parseString(list)
+ if !ok {
+ return nil, nil, false
+ }
+ out = append(out, tuple{string(name), string(data)})
+ }
+ ok = true
+ return
+}
+
+func signatureLength(sig *signature) int {
+ length := 4 // length prefix for signature
+ length += stringLength([]byte(sig.Format))
+ length += stringLength(sig.Blob)
+ return length
+}
+
+func marshalSignature(to []byte, sig *signature) []byte {
+ length := uint32(signatureLength(sig) - 4)
+ to = marshalUint32(to, length)
+ to = marshalString(to, []byte(sig.Format))
+ to = marshalString(to, sig.Blob)
+ return to
+}
+
+func parseSignature(in []byte) (out *signature, rest []byte, ok bool) {
+ var sigBytes, format []byte
+ sig := new(signature)
+
+ if sigBytes, rest, ok = parseString(in); !ok {
+ return
+ }
+
+ if format, sigBytes, ok = parseString(sigBytes); !ok {
+ return
+ }
+ sig.Format = string(format)
+
+ if sig.Blob, sigBytes, ok = parseString(sigBytes); !ok {
+ return
+ }
+
+ return sig, rest, ok
+}
diff --git a/ssh/common.go b/ssh/common.go
index 8850382..e8652dd 100644
--- a/ssh/common.go
+++ b/ssh/common.go
@@ -16,6 +16,7 @@
const (
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
hostAlgoRSA = "ssh-rsa"
+ hostAlgoDSA = "ssh-dss"
macSHA196 = "hmac-sha1-96"
compressionNone = "none"
serviceUserAuth = "ssh-userauth"
@@ -144,6 +145,15 @@
// serialize a signed slice according to RFC 4254 6.6.
func serializeSignature(algoname string, sig []byte) []byte {
+ switch algoname {
+ // The corresponding private key to a public certificate is always a normal
+ // private key. For signature serialization purposes, ensure we use the
+ // proper ssh-rsa or ssh-dss algo name in case the public cert algo name is passed.
+ case hostAlgoRSACertV01:
+ algoname = "ssh-rsa"
+ case hostAlgoDSACertV01:
+ algoname = "ssh-dss"
+ }
length := stringLength([]byte(algoname))
length += stringLength(sig)
@@ -156,33 +166,25 @@
// serialize a *rsa.PublicKey or *dsa.PublicKey according to RFC 4253 6.6.
func serializePublickey(key interface{}) []byte {
+ var pubKeyBytes []byte
algoname := algoName(key)
switch key := key.(type) {
case *rsa.PublicKey:
- e := new(big.Int).SetInt64(int64(key.E))
- length := stringLength([]byte(algoname))
- length += intLength(e)
- length += intLength(key.N)
- ret := make([]byte, length)
- r := marshalString(ret, []byte(algoname))
- r = marshalInt(r, e)
- marshalInt(r, key.N)
- return ret
+ pubKeyBytes = marshalPubRSA(key)
case *dsa.PublicKey:
- length := stringLength([]byte(algoname))
- length += intLength(key.P)
- length += intLength(key.Q)
- length += intLength(key.G)
- length += intLength(key.Y)
- ret := make([]byte, length)
- r := marshalString(ret, []byte(algoname))
- r = marshalInt(r, key.P)
- r = marshalInt(r, key.Q)
- r = marshalInt(r, key.G)
- marshalInt(r, key.Y)
- return ret
+ pubKeyBytes = marshalPubDSA(key)
+ case *OpenSSHCertV01:
+ pubKeyBytes = marshalOpenSSHCertV01(key)
+ default:
+ panic("unexpected key type")
}
- panic("unexpected key type")
+
+ length := stringLength([]byte(algoname))
+ length += len(pubKeyBytes)
+ ret := make([]byte, length)
+ r := marshalString(ret, []byte(algoname))
+ copy(r, pubKeyBytes)
+ return ret
}
func algoName(key interface{}) string {
@@ -191,6 +193,8 @@
return "ssh-rsa"
case *dsa.PublicKey:
return "ssh-dss"
+ case *OpenSSHCertV01:
+ return algoName(key.(*OpenSSHCertV01).Key) + "-cert-v01@openssh.com"
}
panic("unexpected key type")
}
diff --git a/ssh/keys.go b/ssh/keys.go
new file mode 100644
index 0000000..8322697
--- /dev/null
+++ b/ssh/keys.go
@@ -0,0 +1,120 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+ "crypto/dsa"
+ "crypto/rsa"
+ "math/big"
+)
+
+// parsePubKey parses a public key according to RFC 4253, section 6.6.
+func parsePubKey(in []byte) (out interface{}, rest []byte, ok bool) {
+ algo, in, ok := parseString(in)
+ if !ok {
+ return
+ }
+
+ switch string(algo) {
+ case hostAlgoRSA:
+ return parseRSA(in)
+ case hostAlgoDSA:
+ return parseDSA(in)
+ case hostAlgoRSACertV01, hostAlgoDSACertV01:
+ return parseOpenSSHCertV01(in, string(algo))
+ }
+ panic("ssh: unknown public key type")
+}
+
+// parseRSA parses an RSA key according to RFC 4253, section 6.6.
+func parseRSA(in []byte) (out *rsa.PublicKey, rest []byte, ok bool) {
+ key := new(rsa.PublicKey)
+
+ bigE, in, ok := parseInt(in)
+ if !ok || bigE.BitLen() > 24 {
+ return
+ }
+ e := bigE.Int64()
+ if e < 3 || e&1 == 0 {
+ ok = false
+ return
+ }
+ key.E = int(e)
+
+ if key.N, in, ok = parseInt(in); !ok {
+ return
+ }
+
+ ok = true
+ return key, in, ok
+}
+
+// parseDSA parses an DSA key according to RFC 4253, section 6.6.
+func parseDSA(in []byte) (out *dsa.PublicKey, rest []byte, ok bool) {
+ key := new(dsa.PublicKey)
+
+ if key.P, in, ok = parseInt(in); !ok {
+ return
+ }
+
+ if key.Q, in, ok = parseInt(in); !ok {
+ return
+ }
+
+ if key.G, in, ok = parseInt(in); !ok {
+ return
+ }
+
+ if key.Y, in, ok = parseInt(in); !ok {
+ return
+ }
+
+ ok = true
+ return key, in, ok
+}
+
+// marshalPrivRSA serializes an RSA private key according to RFC 4253, section 6.6.
+func marshalPrivRSA(priv *rsa.PrivateKey) []byte {
+ e := new(big.Int).SetInt64(int64(priv.E))
+ length := stringLength([]byte(hostAlgoRSA))
+ length += intLength(e)
+ length += intLength(priv.N)
+
+ ret := make([]byte, length)
+ r := marshalString(ret, []byte(hostAlgoRSA))
+ r = marshalInt(r, e)
+ r = marshalInt(r, priv.N)
+
+ return ret
+}
+
+// marshalPubRSA serializes an RSA public key according to RFC 4253, section 6.6.
+func marshalPubRSA(key *rsa.PublicKey) []byte {
+ e := new(big.Int).SetInt64(int64(key.E))
+ length := intLength(e)
+ length += intLength(key.N)
+
+ ret := make([]byte, length)
+ r := marshalInt(ret, e)
+ r = marshalInt(r, key.N)
+
+ return ret
+}
+
+// marshalPubDSA serializes an DSA public key according to RFC 4253, section 6.6.
+func marshalPubDSA(key *dsa.PublicKey) []byte {
+ length := intLength(key.P)
+ length += intLength(key.Q)
+ length += intLength(key.G)
+ length += intLength(key.Y)
+
+ ret := make([]byte, length)
+ r := marshalInt(ret, key.P)
+ r = marshalInt(r, key.Q)
+ r = marshalInt(r, key.G)
+ marshalInt(r, key.Y)
+
+ return ret
+}
diff --git a/ssh/messages.go b/ssh/messages.go
index 34ad131..eac8b8d 100644
--- a/ssh/messages.go
+++ b/ssh/messages.go
@@ -448,6 +448,23 @@
return
}
+func parseUint64(in []byte) (out uint64, rest []byte, ok bool) {
+ if len(in) < 8 {
+ return
+ }
+ out = uint64(in[0])<<56 |
+ uint64(in[1])<<48 |
+ uint64(in[2])<<40 |
+ uint64(in[3])<<32 |
+ uint64(in[4])<<24 |
+ uint64(in[5])<<16 |
+ uint64(in[6])<<8 |
+ uint64(in[7])
+ rest = in[8:]
+ ok = true
+ return
+}
+
func nameListLength(namelist []string) int {
length := 4 /* uint32 length prefix */
for i, name := range namelist {
diff --git a/ssh/server.go b/ssh/server.go
index 1af0703..ec367c8 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -67,49 +67,10 @@
return err
}
- s.rsaSerialized = marshalRSA(s.rsa)
+ s.rsaSerialized = marshalPrivRSA(s.rsa)
return nil
}
-// marshalRSA serializes an RSA private key according to RFC 4256, section 6.6.
-func marshalRSA(priv *rsa.PrivateKey) []byte {
- e := new(big.Int).SetInt64(int64(priv.E))
- length := stringLength([]byte(hostAlgoRSA))
- length += intLength(e)
- length += intLength(priv.N)
-
- ret := make([]byte, length)
- r := marshalString(ret, []byte(hostAlgoRSA))
- r = marshalInt(r, e)
- r = marshalInt(r, priv.N)
-
- return ret
-}
-
-// parseRSA parses an RSA key according to RFC 4256, section 6.6.
-func parseRSA(in []byte) (pubKey *rsa.PublicKey, ok bool) {
- algo, in, ok := parseString(in)
- if !ok || string(algo) != hostAlgoRSA {
- return nil, false
- }
- bigE, in, ok := parseInt(in)
- if !ok || bigE.BitLen() > 24 {
- return nil, false
- }
- e := bigE.Int64()
- if e < 3 || e&1 == 0 {
- return nil, false
- }
- N, in, ok := parseInt(in)
- if !ok || len(in) > 0 {
- return nil, false
- }
- return &rsa.PublicKey{
- N: N,
- E: int(e),
- }, true
-}
-
func parseRSASig(in []byte) (sig []byte, ok bool) {
algo, in, ok := parseString(in)
if !ok || string(algo) != hostAlgoRSA {
@@ -485,7 +446,11 @@
h := hashFunc.New()
h.Write(signedData)
digest := h.Sum(nil)
- rsaKey, ok := parseRSA(pubKey)
+ key, _, ok := parsePubKey(pubKey)
+ if !ok {
+ return ParseError{msgUserAuthRequest}
+ }
+ rsaKey, ok := key.(*rsa.PublicKey)
if !ok {
return ParseError{msgUserAuthRequest}
}