acme: improve TLSSNI{01,02}ChallengeCert methods
This is a split of https://go-review.googlesource.com/23970 (patch set 8)
to address only Client changes:
1. Expose expected server name value of TLS SNI ClientHello message
2. Fix a bug where returned error value was nil.
Change-Id: I21f571652e9bbef80a2222dc34fce767270b7c48
Reviewed-on: https://go-review.googlesource.com/26852
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/acme/internal/acme/acme.go b/acme/internal/acme/acme.go
index 86398e3..b6d7efb 100644
--- a/acme/internal/acme/acme.go
+++ b/acme/internal/acme/acme.go
@@ -380,16 +380,22 @@
// For more details on TLS-SNI-01 see https://tools.ietf.org/html/draft-ietf-acme-acme-01#section-7.3.
//
// The token argument is a Challenge.Token value.
-// The returned certificate is valid for the next 24 hours.
-func (c *Client) TLSSNI01ChallengeCert(token string) (tls.Certificate, error) {
+//
+// The returned certificate is valid for the next 24 hours and must be presented only when
+// the server name of the client hello matches exactly the returned name value.
+func (c *Client) TLSSNI01ChallengeCert(token string) (cert tls.Certificate, name string, err error) {
ka, err := keyAuth(c.Key.Public(), token)
if err != nil {
- return tls.Certificate{}, nil
+ return tls.Certificate{}, "", err
}
b := sha256.Sum256([]byte(ka))
h := hex.EncodeToString(b[:])
- name := fmt.Sprintf("%s.%s.acme.invalid", h[:32], h[32:])
- return tlsChallengeCert(name)
+ name = fmt.Sprintf("%s.%s.acme.invalid", h[:32], h[32:])
+ cert, err = tlsChallengeCert(name)
+ if err != nil {
+ return tls.Certificate{}, "", err
+ }
+ return cert, name, nil
}
// TLSSNI02ChallengeCert creates a certificate for TLS-SNI-02 challenge response.
@@ -398,21 +404,27 @@
// https://tools.ietf.org/html/draft-ietf-acme-acme-03#section-7.3.
//
// The token argument is a Challenge.Token value.
-// The returned certificate is valid for the next 24 hours.
-func (c *Client) TLSSNI02ChallengeCert(token string) (tls.Certificate, error) {
+//
+// The returned certificate is valid for the next 24 hours and must be presented only when
+// the server name in the client hello matches exactly the returned name value.
+func (c *Client) TLSSNI02ChallengeCert(token string) (cert tls.Certificate, name string, err error) {
b := sha256.Sum256([]byte(token))
h := hex.EncodeToString(b[:])
sanA := fmt.Sprintf("%s.%s.token.acme.invalid", h[:32], h[32:])
ka, err := keyAuth(c.Key.Public(), token)
if err != nil {
- return tls.Certificate{}, nil
+ return tls.Certificate{}, "", err
}
b = sha256.Sum256([]byte(ka))
h = hex.EncodeToString(b[:])
sanB := fmt.Sprintf("%s.%s.ka.acme.invalid", h[:32], h[32:])
- return tlsChallengeCert(sanA, sanB)
+ cert, err = tlsChallengeCert(sanA, sanB)
+ if err != nil {
+ return tls.Certificate{}, "", err
+ }
+ return cert, sanA, nil
}
func (c *Client) httpClient() *http.Client {
diff --git a/acme/internal/acme/acme_test.go b/acme/internal/acme/acme_test.go
index 45e9f30..f8c1d53 100644
--- a/acme/internal/acme/acme_test.go
+++ b/acme/internal/acme/acme_test.go
@@ -16,6 +16,7 @@
"net/http"
"net/http/httptest"
"reflect"
+ "sort"
"strings"
"testing"
"time"
@@ -775,7 +776,7 @@
)
client := &Client{Key: testKey}
- tlscert, err := client.TLSSNI01ChallengeCert(token)
+ tlscert, name, err := client.TLSSNI01ChallengeCert(token)
if err != nil {
t.Fatal(err)
}
@@ -788,7 +789,10 @@
t.Fatal(err)
}
if len(cert.DNSNames) != 1 || cert.DNSNames[0] != san {
- t.Errorf("cert.DNSNames = %v; want %q", cert.DNSNames, san)
+ t.Fatalf("cert.DNSNames = %v; want %q", cert.DNSNames, san)
+ }
+ if cert.DNSNames[0] != name {
+ t.Errorf("cert.DNSNames[0] != name: %q vs %q", cert.DNSNames[0], name)
}
}
func TestTLSSNI02ChallengeCert(t *testing.T) {
@@ -801,7 +805,7 @@
)
client := &Client{Key: testKey}
- tlscert, err := client.TLSSNI02ChallengeCert(token)
+ tlscert, name, err := client.TLSSNI02ChallengeCert(token)
if err != nil {
t.Fatal(err)
}
@@ -815,6 +819,11 @@
}
names := []string{sanA, sanB}
if !reflect.DeepEqual(cert.DNSNames, names) {
- t.Errorf("cert.DNSNames = %v;\nwant %v", cert.DNSNames, names)
+ t.Fatalf("cert.DNSNames = %v;\nwant %v", cert.DNSNames, names)
+ }
+ sort.Strings(cert.DNSNames)
+ i := sort.SearchStrings(cert.DNSNames, name)
+ if i >= len(cert.DNSNames) || cert.DNSNames[i] != name {
+ t.Errorf("%v doesn't have %q", cert.DNSNames, name)
}
}