acme: support custom crypto.Signer implementations

Currently, only rsa.PrivateKey and ecdsa.PrivateKey are supported
when creating JWS signatures. However, it is unnecessarily limiting
because any crypto.Signer implementation can sign a digest
in the appropriate format.

This change uses key.Public() instead of type-asserting the private
key which allows for a custom crypto.Signer implementation.
For instance, a key stored in a hardware module where the latter
does the actual signing without the key ever leaving its boundaries.

Change-Id: Ie7930ea2ba8c49dde7107ff074ae34abec05bdb9
Reviewed-on: https://go-review.googlesource.com/c/145137
Run-TryBot: Alex Vaghin <ddos@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
diff --git a/acme/acme.go b/acme/acme.go
index 7df6476..c6fd998 100644
--- a/acme/acme.go
+++ b/acme/acme.go
@@ -77,6 +77,10 @@
 type Client struct {
 	// Key is the account key used to register with a CA and sign requests.
 	// Key.Public() must return a *rsa.PublicKey or *ecdsa.PublicKey.
+	//
+	// The following algorithms are supported:
+	// RS256, ES256, ES384 and ES512.
+	// See RFC7518 for more details about the algorithms.
 	Key crypto.Signer
 
 	// HTTPClient optionally specifies an HTTP client to use
diff --git a/acme/jws.go b/acme/jws.go
index 6cbca25..1093b50 100644
--- a/acme/jws.go
+++ b/acme/jws.go
@@ -25,7 +25,7 @@
 	if err != nil {
 		return nil, err
 	}
-	alg, sha := jwsHasher(key)
+	alg, sha := jwsHasher(key.Public())
 	if alg == "" || !sha.Available() {
 		return nil, ErrUnsupportedKey
 	}
@@ -97,13 +97,16 @@
 }
 
 // jwsSign signs the digest using the given key.
-// It returns ErrUnsupportedKey if the key type is unknown.
-// The hash is used only for RSA keys.
+// The hash is unused for ECDSA keys.
+//
+// Note: non-stdlib crypto.Signer implementations are expected to return
+// the signature in the format as specified in RFC7518.
+// See https://tools.ietf.org/html/rfc7518 for more details.
 func jwsSign(key crypto.Signer, hash crypto.Hash, digest []byte) ([]byte, error) {
-	switch key := key.(type) {
-	case *rsa.PrivateKey:
-		return key.Sign(rand.Reader, digest, hash)
-	case *ecdsa.PrivateKey:
+	if key, ok := key.(*ecdsa.PrivateKey); ok {
+		// The key.Sign method of ecdsa returns ASN1-encoded signature.
+		// So, we use the package Sign function instead
+		// to get R and S values directly and format the result accordingly.
 		r, s, err := ecdsa.Sign(rand.Reader, key, digest)
 		if err != nil {
 			return nil, err
@@ -118,18 +121,18 @@
 		copy(sig[size*2-len(sb):], sb)
 		return sig, nil
 	}
-	return nil, ErrUnsupportedKey
+	return key.Sign(rand.Reader, digest, hash)
 }
 
 // jwsHasher indicates suitable JWS algorithm name and a hash function
 // to use for signing a digest with the provided key.
 // It returns ("", 0) if the key is not supported.
-func jwsHasher(key crypto.Signer) (string, crypto.Hash) {
-	switch key := key.(type) {
-	case *rsa.PrivateKey:
+func jwsHasher(pub crypto.PublicKey) (string, crypto.Hash) {
+	switch pub := pub.(type) {
+	case *rsa.PublicKey:
 		return "RS256", crypto.SHA256
-	case *ecdsa.PrivateKey:
-		switch key.Params().Name {
+	case *ecdsa.PublicKey:
+		switch pub.Params().Name {
 		case "P-256":
 			return "ES256", crypto.SHA256
 		case "P-384":
diff --git a/acme/jws_test.go b/acme/jws_test.go
index 0ff0fb5..ee30b1e 100644
--- a/acme/jws_test.go
+++ b/acme/jws_test.go
@@ -5,6 +5,7 @@
 package acme
 
 import (
+	"crypto"
 	"crypto/ecdsa"
 	"crypto/elliptic"
 	"crypto/rsa"
@@ -13,6 +14,7 @@
 	"encoding/json"
 	"encoding/pem"
 	"fmt"
+	"io"
 	"math/big"
 	"testing"
 )
@@ -241,6 +243,79 @@
 	}
 }
 
+type customTestSigner struct {
+	sig []byte
+	pub crypto.PublicKey
+}
+
+func (s *customTestSigner) Public() crypto.PublicKey { return s.pub }
+func (s *customTestSigner) Sign(io.Reader, []byte, crypto.SignerOpts) ([]byte, error) {
+	return s.sig, nil
+}
+
+func TestJWSEncodeJSONCustom(t *testing.T) {
+	claims := struct{ Msg string }{"hello"}
+	const (
+		// printf '{"Msg":"hello"}' | base64 | tr -d '=' | tr '/+' '_-'
+		payload = "eyJNc2ciOiJoZWxsbyJ9"
+		// printf 'testsig' | base64 | tr -d '='
+		testsig = "dGVzdHNpZw"
+
+		// printf '{"alg":"ES256","jwk":{"crv":"P-256","kty":"EC","x":<testKeyECPubY>,"y":<testKeyECPubY>,"nonce":"nonce"}' | \
+		// base64 | tr -d '=' | tr '/+' '_-'
+		es256phead = "eyJhbGciOiJFUzI1NiIsImp3ayI6eyJjcnYiOiJQLTI1NiIsImt0eSI6IkVDIiwieCI6IjVsaEV1" +
+			"ZzV4SzR4QkRaMm5BYmF4THRhTGl2ODVieEo3ZVBkMWRrTzIzSFEiLCJ5IjoiNGFpSzcyc0JlVUFH" +
+			"a3YwVGFMc213b2tZVVl5TnhHc1M1RU1JS3dzTklLayJ9LCJub25jZSI6Im5vbmNlIn0"
+
+		// {"alg":"RS256","jwk":{"e":"AQAB","kty":"RSA","n":"..."},"nonce":"nonce"}
+		rs256phead = "eyJhbGciOiJSUzI1NiIsImp3ayI6eyJlIjoiQVFBQiIsImt0eSI6" +
+			"IlJTQSIsIm4iOiI0eGdaM2VSUGt3b1J2eTdxZVJVYm1NRGUwVi14" +
+			"SDllV0xkdTBpaGVlTGxybUQybXFXWGZQOUllU0tBcGJuMzRnOFR1" +
+			"QVM5ZzV6aHE4RUxRM2ttanItS1Y4NkdBTWdJNlZBY0dscTNRcnpw" +
+			"VENmXzMwQWI3LXphd3JmUmFGT05hMUh3RXpQWTFLSG5HVmt4SmM4" +
+			"NWdOa3dZSTlTWTJSSFh0dmxuM3pzNXdJVE5yZG9zcUVYZWFJa1ZZ" +
+			"QkVoYmhOdTU0cHAza3hvNlR1V0xpOWU2cFhlV2V0RXdtbEJ3dFda" +
+			"bFBvaWIyajNUeExCa3NLWmZveUZ5ZWszODBtSGdKQXVtUV9JMmZq" +
+			"ajk4Xzk3bWszaWhPWTRBZ1ZkQ0RqMXpfR0NvWmtHNVJxN25iQ0d5" +
+			"b3N5S1d5RFgwMFpzLW5OcVZob0xlSXZYQzRubldkSk1aNnJvZ3h5" +
+			"UVEifSwibm9uY2UiOiJub25jZSJ9"
+	)
+
+	tt := []struct {
+		alg, phead string
+		pub        crypto.PublicKey
+	}{
+		{"RS256", rs256phead, testKey.Public()},
+		{"ES256", es256phead, testKeyEC.Public()},
+	}
+	for _, tc := range tt {
+		tc := tc
+		t.Run(tc.alg, func(t *testing.T) {
+			signer := &customTestSigner{
+				sig: []byte("testsig"),
+				pub: tc.pub,
+			}
+			b, err := jwsEncodeJSON(claims, signer, "nonce")
+			if err != nil {
+				t.Fatal(err)
+			}
+			var j struct{ Protected, Payload, Signature string }
+			if err := json.Unmarshal(b, &j); err != nil {
+				t.Fatal(err)
+			}
+			if j.Protected != tc.phead {
+				t.Errorf("j.Protected = %q\nwant %q", j.Protected, tc.phead)
+			}
+			if j.Payload != payload {
+				t.Errorf("j.Payload = %q\nwant %q", j.Payload, payload)
+			}
+			if j.Signature != testsig {
+				t.Errorf("j.Signature = %q\nwant %q", j.Signature, testsig)
+			}
+		})
+	}
+}
+
 func TestJWKThumbprintRSA(t *testing.T) {
 	// Key example from RFC 7638
 	const base64N = "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAt" +