go.crypto/ssh: add client support for OpenSSH certificates
Refactor key parsing, marshaling, and serialization to be a bit more flexible

R=agl, dave, djm
CC=golang-dev
https://golang.org/cl/5650067
diff --git a/ssh/certs.go b/ssh/certs.go
new file mode 100644
index 0000000..59430ea
--- /dev/null
+++ b/ssh/certs.go
@@ -0,0 +1,306 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+// References
+//   [PROTOCOL.certkeys]: http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys
+
+import (
+	"crypto/dsa"
+	"crypto/rsa"
+	"time"
+)
+
+// String constants in [PROTOCOL.certkeys] for certificate algorithm names.
+const (
+	hostAlgoRSACertV01      = "ssh-rsa-cert-v01@openssh.com"
+	hostAlgoDSACertV01      = "ssh-dss-cert-v01@openssh.com"
+	hostAlgoECDSA256CertV01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com"
+	hostAlgoECDSA384CertV01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com"
+	hostAlgoECDSA521CertV01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com"
+)
+
+// Certificate types are used to specify whether a certificate is for identification
+// of a user or a host.  Current identities are defined in [PROTOCOL.certkeys].
+const (
+	UserCert = 1
+	HostCert = 2
+)
+
+type signature struct {
+	Format string
+	Blob   []byte
+}
+
+type tuple struct {
+	Name string
+	Data string
+}
+
+// An OpenSSHCertV01 represents an OpenSSH certificate as defined in
+// [PROTOCOL.certkeys] rev 1.8. Supported formats include
+// ssh-rsa-cert-v01@openssh.com and ssh-dss-cert-v01@openssh.com.
+type OpenSSHCertV01 struct {
+	Nonce                   []byte
+	Key                     interface{} // rsa or dsa *PublicKey
+	Serial                  uint64
+	Type                    uint32
+	KeyId                   string
+	ValidPrincipals         []string
+	ValidAfter, ValidBefore time.Time
+	CriticalOptions         []tuple
+	Extensions              []tuple
+	Reserved                []byte
+	SignatureKey            interface{} // rsa, dsa, or ecdsa *PublicKey
+	Signature               *signature
+}
+
+func parseOpenSSHCertV01(in []byte, algo string) (out *OpenSSHCertV01, rest []byte, ok bool) {
+	cert := new(OpenSSHCertV01)
+
+	if cert.Nonce, in, ok = parseString(in); !ok {
+		return
+	}
+
+	switch algo {
+	case hostAlgoRSACertV01:
+		var rsaPubKey *rsa.PublicKey
+		if rsaPubKey, in, ok = parseRSA(in); !ok {
+			return
+		}
+		cert.Key = rsaPubKey
+	case hostAlgoDSACertV01:
+		var dsaPubKey *dsa.PublicKey
+		if dsaPubKey, in, ok = parseDSA(in); !ok {
+			return
+		}
+		cert.Key = dsaPubKey
+	default:
+		return
+	}
+
+	if cert.Serial, in, ok = parseUint64(in); !ok {
+		return
+	}
+
+	if cert.Type, in, ok = parseUint32(in); !ok || cert.Type != UserCert && cert.Type != HostCert {
+		return
+	}
+
+	keyId, in, ok := parseString(in)
+	if !ok {
+		return
+	}
+	cert.KeyId = string(keyId)
+
+	if cert.ValidPrincipals, in, ok = parseLengthPrefixedNameList(in); !ok {
+		return
+	}
+
+	va, in, ok := parseUint64(in)
+	if !ok {
+		return
+	}
+	cert.ValidAfter = time.Unix(int64(va), 0)
+
+	vb, in, ok := parseUint64(in)
+	if !ok {
+		return
+	}
+	cert.ValidBefore = time.Unix(int64(vb), 0)
+
+	if cert.CriticalOptions, in, ok = parseTupleList(in); !ok {
+		return
+	}
+
+	if cert.Extensions, in, ok = parseTupleList(in); !ok {
+		return
+	}
+
+	if cert.Reserved, in, ok = parseString(in); !ok {
+		return
+	}
+
+	sigKey, in, ok := parseString(in)
+	if !ok {
+		return
+	}
+	if cert.SignatureKey, _, ok = parsePubKey(sigKey); !ok {
+		return
+	}
+
+	if cert.Signature, in, ok = parseSignature(in); !ok {
+		return
+	}
+
+	ok = true
+	return cert, in, ok
+}
+
+func marshalOpenSSHCertV01(cert *OpenSSHCertV01) []byte {
+	var pubKey []byte
+	switch cert.Key.(type) {
+	case *rsa.PublicKey:
+		k := cert.Key.(*rsa.PublicKey)
+		pubKey = marshalPubRSA(k)
+	case *dsa.PublicKey:
+		k := cert.Key.(*dsa.PublicKey)
+		pubKey = marshalPubDSA(k)
+	default:
+		panic("ssh: unknown public key type in cert")
+	}
+
+	sigKey := serializePublickey(cert.SignatureKey)
+
+	length := stringLength(cert.Nonce)
+	length += len(pubKey)
+	length += 8 // Length of Serial
+	length += 4 // Length of Type
+	length += stringLength([]byte(cert.KeyId))
+	length += lengthPrefixedNameListLength(cert.ValidPrincipals)
+	length += 8 // Length of ValidAfter
+	length += 8 // Length of ValidBefore
+	length += tupleListLength(cert.CriticalOptions)
+	length += tupleListLength(cert.Extensions)
+	length += stringLength(cert.Reserved)
+	length += stringLength(sigKey)
+	length += signatureLength(cert.Signature)
+
+	ret := make([]byte, length)
+	r := marshalString(ret, cert.Nonce)
+	copy(r, pubKey)
+	r = r[len(pubKey):]
+	r = marshalUint64(r, cert.Serial)
+	r = marshalUint32(r, cert.Type)
+	r = marshalString(r, []byte(cert.KeyId))
+	r = marshalLengthPrefixedNameList(r, cert.ValidPrincipals)
+	r = marshalUint64(r, uint64(cert.ValidAfter.Unix()))
+	r = marshalUint64(r, uint64(cert.ValidBefore.Unix()))
+	r = marshalTupleList(r, cert.CriticalOptions)
+	r = marshalTupleList(r, cert.Extensions)
+	r = marshalString(r, cert.Reserved)
+	r = marshalString(r, sigKey)
+	r = marshalSignature(r, cert.Signature)
+	if len(r) > 0 {
+		panic("internal error")
+	}
+	return ret
+}
+
+func lengthPrefixedNameListLength(namelist []string) int {
+	length := 4 // length prefix for list
+	for _, name := range namelist {
+		length += 4 // length prefix for name
+		length += len(name)
+	}
+	return length
+}
+
+func marshalLengthPrefixedNameList(to []byte, namelist []string) []byte {
+	length := uint32(lengthPrefixedNameListLength(namelist) - 4)
+	to = marshalUint32(to, length)
+	for _, name := range namelist {
+		to = marshalString(to, []byte(name))
+	}
+	return to
+}
+
+func parseLengthPrefixedNameList(in []byte) (out []string, rest []byte, ok bool) {
+	list, rest, ok := parseString(in)
+	if !ok {
+		return
+	}
+
+	for len(list) > 0 {
+		var next []byte
+		var ok bool
+		next, list, ok = parseString(list)
+		if !ok {
+			return nil, nil, false
+		}
+		out = append(out, string(next))
+	}
+	ok = true
+	return
+}
+
+func tupleListLength(tupleList []tuple) int {
+	length := 4 // length prefix for list
+	for _, t := range tupleList {
+		length += 4 // length prefix for t.Name
+		length += len(t.Name)
+		length += 4 // length prefix for t.Data
+		length += len(t.Data)
+	}
+	return length
+}
+
+func marshalTupleList(to []byte, tuplelist []tuple) []byte {
+	length := uint32(tupleListLength(tuplelist) - 4)
+	to = marshalUint32(to, length)
+	for _, t := range tuplelist {
+		to = marshalString(to, []byte(t.Name))
+		to = marshalString(to, []byte(t.Data))
+	}
+	return to
+}
+
+func parseTupleList(in []byte) (out []tuple, rest []byte, ok bool) {
+	list, rest, ok := parseString(in)
+	if !ok {
+		return
+	}
+
+	for len(list) > 0 {
+		var name, data []byte
+		var ok bool
+		name, list, ok = parseString(list)
+		if !ok {
+			return nil, nil, false
+		}
+		data, list, ok = parseString(list)
+		if !ok {
+			return nil, nil, false
+		}
+		out = append(out, tuple{string(name), string(data)})
+	}
+	ok = true
+	return
+}
+
+func signatureLength(sig *signature) int {
+	length := 4 // length prefix for signature
+	length += stringLength([]byte(sig.Format))
+	length += stringLength(sig.Blob)
+	return length
+}
+
+func marshalSignature(to []byte, sig *signature) []byte {
+	length := uint32(signatureLength(sig) - 4)
+	to = marshalUint32(to, length)
+	to = marshalString(to, []byte(sig.Format))
+	to = marshalString(to, sig.Blob)
+	return to
+}
+
+func parseSignature(in []byte) (out *signature, rest []byte, ok bool) {
+	var sigBytes, format []byte
+	sig := new(signature)
+
+	if sigBytes, rest, ok = parseString(in); !ok {
+		return
+	}
+
+	if format, sigBytes, ok = parseString(sigBytes); !ok {
+		return
+	}
+	sig.Format = string(format)
+
+	if sig.Blob, sigBytes, ok = parseString(sigBytes); !ok {
+		return
+	}
+
+	return sig, rest, ok
+}
diff --git a/ssh/common.go b/ssh/common.go
index 8850382..e8652dd 100644
--- a/ssh/common.go
+++ b/ssh/common.go
@@ -16,6 +16,7 @@
 const (
 	kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
 	hostAlgoRSA     = "ssh-rsa"
+	hostAlgoDSA     = "ssh-dss"
 	macSHA196       = "hmac-sha1-96"
 	compressionNone = "none"
 	serviceUserAuth = "ssh-userauth"
@@ -144,6 +145,15 @@
 
 // serialize a signed slice according to RFC 4254 6.6.
 func serializeSignature(algoname string, sig []byte) []byte {
+	switch algoname {
+	// The corresponding private key to a public certificate is always a normal
+	// private key.  For signature serialization purposes, ensure we use the
+	// proper ssh-rsa or ssh-dss algo name in case the public cert algo name is passed.
+	case hostAlgoRSACertV01:
+		algoname = "ssh-rsa"
+	case hostAlgoDSACertV01:
+		algoname = "ssh-dss"
+	}
 	length := stringLength([]byte(algoname))
 	length += stringLength(sig)
 
@@ -156,33 +166,25 @@
 
 // serialize a *rsa.PublicKey or *dsa.PublicKey according to RFC 4253 6.6.
 func serializePublickey(key interface{}) []byte {
+	var pubKeyBytes []byte
 	algoname := algoName(key)
 	switch key := key.(type) {
 	case *rsa.PublicKey:
-		e := new(big.Int).SetInt64(int64(key.E))
-		length := stringLength([]byte(algoname))
-		length += intLength(e)
-		length += intLength(key.N)
-		ret := make([]byte, length)
-		r := marshalString(ret, []byte(algoname))
-		r = marshalInt(r, e)
-		marshalInt(r, key.N)
-		return ret
+		pubKeyBytes = marshalPubRSA(key)
 	case *dsa.PublicKey:
-		length := stringLength([]byte(algoname))
-		length += intLength(key.P)
-		length += intLength(key.Q)
-		length += intLength(key.G)
-		length += intLength(key.Y)
-		ret := make([]byte, length)
-		r := marshalString(ret, []byte(algoname))
-		r = marshalInt(r, key.P)
-		r = marshalInt(r, key.Q)
-		r = marshalInt(r, key.G)
-		marshalInt(r, key.Y)
-		return ret
+		pubKeyBytes = marshalPubDSA(key)
+	case *OpenSSHCertV01:
+		pubKeyBytes = marshalOpenSSHCertV01(key)
+	default:
+		panic("unexpected key type")
 	}
-	panic("unexpected key type")
+
+	length := stringLength([]byte(algoname))
+	length += len(pubKeyBytes)
+	ret := make([]byte, length)
+	r := marshalString(ret, []byte(algoname))
+	copy(r, pubKeyBytes)
+	return ret
 }
 
 func algoName(key interface{}) string {
@@ -191,6 +193,8 @@
 		return "ssh-rsa"
 	case *dsa.PublicKey:
 		return "ssh-dss"
+	case *OpenSSHCertV01:
+		return algoName(key.(*OpenSSHCertV01).Key) + "-cert-v01@openssh.com"
 	}
 	panic("unexpected key type")
 }
diff --git a/ssh/keys.go b/ssh/keys.go
new file mode 100644
index 0000000..8322697
--- /dev/null
+++ b/ssh/keys.go
@@ -0,0 +1,120 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+	"crypto/dsa"
+	"crypto/rsa"
+	"math/big"
+)
+
+// parsePubKey parses a public key according to RFC 4253, section 6.6.
+func parsePubKey(in []byte) (out interface{}, rest []byte, ok bool) {
+	algo, in, ok := parseString(in)
+	if !ok {
+		return
+	}
+
+	switch string(algo) {
+	case hostAlgoRSA:
+		return parseRSA(in)
+	case hostAlgoDSA:
+		return parseDSA(in)
+	case hostAlgoRSACertV01, hostAlgoDSACertV01:
+		return parseOpenSSHCertV01(in, string(algo))
+	}
+	panic("ssh: unknown public key type")
+}
+
+// parseRSA parses an RSA key according to RFC 4253, section 6.6.
+func parseRSA(in []byte) (out *rsa.PublicKey, rest []byte, ok bool) {
+	key := new(rsa.PublicKey)
+
+	bigE, in, ok := parseInt(in)
+	if !ok || bigE.BitLen() > 24 {
+		return
+	}
+	e := bigE.Int64()
+	if e < 3 || e&1 == 0 {
+		ok = false
+		return
+	}
+	key.E = int(e)
+
+	if key.N, in, ok = parseInt(in); !ok {
+		return
+	}
+
+	ok = true
+	return key, in, ok
+}
+
+// parseDSA parses an DSA key according to RFC 4253, section 6.6.
+func parseDSA(in []byte) (out *dsa.PublicKey, rest []byte, ok bool) {
+	key := new(dsa.PublicKey)
+
+	if key.P, in, ok = parseInt(in); !ok {
+		return
+	}
+
+	if key.Q, in, ok = parseInt(in); !ok {
+		return
+	}
+
+	if key.G, in, ok = parseInt(in); !ok {
+		return
+	}
+
+	if key.Y, in, ok = parseInt(in); !ok {
+		return
+	}
+
+	ok = true
+	return key, in, ok
+}
+
+// marshalPrivRSA serializes an RSA private key according to RFC 4253, section 6.6.
+func marshalPrivRSA(priv *rsa.PrivateKey) []byte {
+	e := new(big.Int).SetInt64(int64(priv.E))
+	length := stringLength([]byte(hostAlgoRSA))
+	length += intLength(e)
+	length += intLength(priv.N)
+
+	ret := make([]byte, length)
+	r := marshalString(ret, []byte(hostAlgoRSA))
+	r = marshalInt(r, e)
+	r = marshalInt(r, priv.N)
+
+	return ret
+}
+
+// marshalPubRSA serializes an RSA public key according to RFC 4253, section 6.6.
+func marshalPubRSA(key *rsa.PublicKey) []byte {
+	e := new(big.Int).SetInt64(int64(key.E))
+	length := intLength(e)
+	length += intLength(key.N)
+
+	ret := make([]byte, length)
+	r := marshalInt(ret, e)
+	r = marshalInt(r, key.N)
+
+	return ret
+}
+
+// marshalPubDSA serializes an DSA public key according to RFC 4253, section 6.6.
+func marshalPubDSA(key *dsa.PublicKey) []byte {
+	length := intLength(key.P)
+	length += intLength(key.Q)
+	length += intLength(key.G)
+	length += intLength(key.Y)
+
+	ret := make([]byte, length)
+	r := marshalInt(ret, key.P)
+	r = marshalInt(r, key.Q)
+	r = marshalInt(r, key.G)
+	marshalInt(r, key.Y)
+
+	return ret
+}
diff --git a/ssh/messages.go b/ssh/messages.go
index 34ad131..eac8b8d 100644
--- a/ssh/messages.go
+++ b/ssh/messages.go
@@ -448,6 +448,23 @@
 	return
 }
 
+func parseUint64(in []byte) (out uint64, rest []byte, ok bool) {
+	if len(in) < 8 {
+		return
+	}
+	out = uint64(in[0])<<56 |
+		uint64(in[1])<<48 |
+		uint64(in[2])<<40 |
+		uint64(in[3])<<32 |
+		uint64(in[4])<<24 |
+		uint64(in[5])<<16 |
+		uint64(in[6])<<8 |
+		uint64(in[7])
+	rest = in[8:]
+	ok = true
+	return
+}
+
 func nameListLength(namelist []string) int {
 	length := 4 /* uint32 length prefix */
 	for i, name := range namelist {
diff --git a/ssh/server.go b/ssh/server.go
index 1af0703..ec367c8 100644
--- a/ssh/server.go
+++ b/ssh/server.go
@@ -67,49 +67,10 @@
 		return err
 	}
 
-	s.rsaSerialized = marshalRSA(s.rsa)
+	s.rsaSerialized = marshalPrivRSA(s.rsa)
 	return nil
 }
 
-// marshalRSA serializes an RSA private key according to RFC 4256, section 6.6.
-func marshalRSA(priv *rsa.PrivateKey) []byte {
-	e := new(big.Int).SetInt64(int64(priv.E))
-	length := stringLength([]byte(hostAlgoRSA))
-	length += intLength(e)
-	length += intLength(priv.N)
-
-	ret := make([]byte, length)
-	r := marshalString(ret, []byte(hostAlgoRSA))
-	r = marshalInt(r, e)
-	r = marshalInt(r, priv.N)
-
-	return ret
-}
-
-// parseRSA parses an RSA key according to RFC 4256, section 6.6.
-func parseRSA(in []byte) (pubKey *rsa.PublicKey, ok bool) {
-	algo, in, ok := parseString(in)
-	if !ok || string(algo) != hostAlgoRSA {
-		return nil, false
-	}
-	bigE, in, ok := parseInt(in)
-	if !ok || bigE.BitLen() > 24 {
-		return nil, false
-	}
-	e := bigE.Int64()
-	if e < 3 || e&1 == 0 {
-		return nil, false
-	}
-	N, in, ok := parseInt(in)
-	if !ok || len(in) > 0 {
-		return nil, false
-	}
-	return &rsa.PublicKey{
-		N: N,
-		E: int(e),
-	}, true
-}
-
 func parseRSASig(in []byte) (sig []byte, ok bool) {
 	algo, in, ok := parseString(in)
 	if !ok || string(algo) != hostAlgoRSA {
@@ -485,7 +446,11 @@
 					h := hashFunc.New()
 					h.Write(signedData)
 					digest := h.Sum(nil)
-					rsaKey, ok := parseRSA(pubKey)
+					key, _, ok := parsePubKey(pubKey)
+					if !ok {
+						return ParseError{msgUserAuthRequest}
+					}
+					rsaKey, ok := key.(*rsa.PublicKey)
 					if !ok {
 						return ParseError{msgUserAuthRequest}
 					}