acme: context-aware Client methods
This change adds a context to all exported methods of Client
which may perform network requests, to allow for easier control
over request timeouts and cancellation.
Change-Id: I635a4d7ad39a63ed9e6823b1af12fbb201c19647
Reviewed-on: https://go-review.googlesource.com/27091
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/acme/internal/acme/acme.go b/acme/internal/acme/acme.go
index 4255bae..ac9c5e0 100644
--- a/acme/internal/acme/acme.go
+++ b/acme/internal/acme/acme.go
@@ -78,7 +78,7 @@
// It caches successful result. So, subsequent calls will not result in
// a network round-trip. This also means mutating c.DirectoryURL after successful call
// of this method will have no effect.
-func (c *Client) Discover() (Directory, error) {
+func (c *Client) Discover(ctx context.Context) (Directory, error) {
c.dirMu.Lock()
defer c.dirMu.Unlock()
if c.dir != nil {
@@ -89,7 +89,7 @@
if dirURL == "" {
dirURL = LetsEncryptURL
}
- res, err := c.httpClient().Get(dirURL)
+ res, err := ctxhttp.Get(ctx, c.HTTPClient, dirURL)
if err != nil {
return Directory{}, err
}
@@ -136,7 +136,7 @@
// CreateCert returns an error if the CA's response or chain was unreasonably large.
// Callers are encouraged to parse the returned value to ensure the certificate is valid and has the expected features.
func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration, bundle bool) (der [][]byte, certURL string, err error) {
- if _, err := c.Discover(); err != nil {
+ if _, err := c.Discover(ctx); err != nil {
return nil, "", err
}
@@ -155,7 +155,7 @@
req.NotAfter = now.Add(exp).Format(time.RFC3339)
}
- res, err := c.postJWS(c.dir.CertURL, req)
+ res, err := postJWS(ctx, c.HTTPClient, c.Key, c.dir.CertURL, req)
if err != nil {
return nil, "", err
}
@@ -171,7 +171,7 @@
return cert, curl, err
}
// slurp issued cert and CA chain, if requested
- cert, err := responseCert(ctx, c.httpClient(), res, bundle)
+ cert, err := responseCert(ctx, c.HTTPClient, res, bundle)
return cert, curl, err
}
@@ -186,13 +186,13 @@
// and has expected features.
func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) {
for {
- res, err := ctxhttp.Get(ctx, c.httpClient(), url)
+ res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode == http.StatusOK {
- return responseCert(ctx, c.httpClient(), res, bundle)
+ return responseCert(ctx, c.HTTPClient, res, bundle)
}
if res.StatusCode > 299 {
return nil, responseError(res)
@@ -221,13 +221,13 @@
// If so, and the account has not indicated the acceptance of the terms (see Account for details),
// Register calls prompt with a TOS URL provided by the CA. Prompt should report
// whether the caller agrees to the terms. To always accept the terms, the caller can use AcceptTOS.
-func (c *Client) Register(a *Account, prompt func(tosURL string) bool) (*Account, error) {
- if _, err := c.Discover(); err != nil {
+func (c *Client) Register(ctx context.Context, a *Account, prompt func(tosURL string) bool) (*Account, error) {
+ if _, err := c.Discover(ctx); err != nil {
return nil, err
}
var err error
- if a, err = c.doReg(c.dir.RegURL, "new-reg", a); err != nil {
+ if a, err = c.doReg(ctx, c.dir.RegURL, "new-reg", a); err != nil {
return nil, err
}
var accept bool
@@ -236,15 +236,15 @@
}
if accept {
a.AgreedTerms = a.CurrentTerms
- a, err = c.UpdateReg(a)
+ a, err = c.UpdateReg(ctx, a)
}
return a, err
}
// GetReg retrieves an existing registration.
// The url argument is an Account URI.
-func (c *Client) GetReg(url string) (*Account, error) {
- a, err := c.doReg(url, "reg", nil)
+func (c *Client) GetReg(ctx context.Context, url string) (*Account, error) {
+ a, err := c.doReg(ctx, url, "reg", nil)
if err != nil {
return nil, err
}
@@ -254,9 +254,9 @@
// UpdateReg updates an existing registration.
// It returns an updated account copy. The provided account is not modified.
-func (c *Client) UpdateReg(a *Account) (*Account, error) {
+func (c *Client) UpdateReg(ctx context.Context, a *Account) (*Account, error) {
uri := a.URI
- a, err := c.doReg(uri, "reg", a)
+ a, err := c.doReg(ctx, uri, "reg", a)
if err != nil {
return nil, err
}
@@ -267,8 +267,8 @@
// Authorize performs the initial step in an authorization flow.
// The caller will then need to choose from and perform a set of returned
// challenges using c.Accept in order to successfully complete authorization.
-func (c *Client) Authorize(domain string) (*Authorization, error) {
- if _, err := c.Discover(); err != nil {
+func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization, error) {
+ if _, err := c.Discover(ctx); err != nil {
return nil, err
}
@@ -283,7 +283,7 @@
Resource: "new-authz",
Identifier: authzID{Type: "dns", Value: domain},
}
- res, err := c.postJWS(c.dir.AuthzURL, req)
+ res, err := postJWS(ctx, c.HTTPClient, c.Key, c.dir.AuthzURL, req)
if err != nil {
return nil, err
}
@@ -305,8 +305,8 @@
// GetAuthz retrieves the current status of an authorization flow.
//
// A client typically polls an authz status using this method.
-func (c *Client) GetAuthz(url string) (*Authorization, error) {
- res, err := c.httpClient().Get(url)
+func (c *Client) GetAuthz(ctx context.Context, url string) (*Authorization, error) {
+ res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
if err != nil {
return nil, err
}
@@ -324,8 +324,8 @@
// GetChallenge retrieves the current status of an challenge.
//
// A client typically polls a challenge status using this method.
-func (c *Client) GetChallenge(url string) (*Challenge, error) {
- res, err := c.httpClient().Get(url)
+func (c *Client) GetChallenge(ctx context.Context, url string) (*Challenge, error) {
+ res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
if err != nil {
return nil, err
}
@@ -344,7 +344,7 @@
// previously obtained with c.Authorize.
//
// The server will then perform the validation asynchronously.
-func (c *Client) Accept(chal *Challenge) (*Challenge, error) {
+func (c *Client) Accept(ctx context.Context, chal *Challenge) (*Challenge, error) {
auth, err := keyAuth(c.Key.Public(), chal.Token)
if err != nil {
return nil, err
@@ -359,7 +359,7 @@
Type: chal.Type,
Auth: auth,
}
- res, err := c.postJWS(chal.URI, req)
+ res, err := postJWS(ctx, c.HTTPClient, c.Key, chal.URI, req)
if err != nil {
return nil, err
}
@@ -452,31 +452,6 @@
return cert, sanA, nil
}
-func (c *Client) httpClient() *http.Client {
- if c.HTTPClient != nil {
- return c.HTTPClient
- }
- return http.DefaultClient
-}
-
-// postJWS signs body and posts it to the provided url.
-// The body argument must be JSON-serializable.
-func (c *Client) postJWS(url string, body interface{}) (*http.Response, error) {
- nonce, err := fetchNonce(c.httpClient(), url)
- if err != nil {
- return nil, err
- }
- b, err := jwsEncodeJSON(body, c.Key, nonce)
- if err != nil {
- return nil, err
- }
- req, err := http.NewRequest("POST", url, bytes.NewReader(b))
- if err != nil {
- return nil, err
- }
- return c.httpClient().Do(req)
-}
-
// doReg sends all types of registration requests.
// The type of request is identified by typ argument, which is a "resource"
// in the ACME spec terms.
@@ -484,10 +459,7 @@
// A non-nil acct argument indicates whether the intention is to mutate data
// of the Account. Only Contact and Agreement of its fields are used
// in such cases.
-//
-// The fields of acct will be populate with the server response
-// and may be overwritten.
-func (c *Client) doReg(url string, typ string, acct *Account) (*Account, error) {
+func (c *Client) doReg(ctx context.Context, url string, typ string, acct *Account) (*Account, error) {
req := struct {
Resource string `json:"resource"`
Contact []string `json:"contact,omitempty"`
@@ -499,7 +471,7 @@
req.Contact = acct.Contact
req.Agreement = acct.AgreedTerms
}
- res, err := c.postJWS(url, req)
+ res, err := postJWS(ctx, c.HTTPClient, c.Key, url, req)
if err != nil {
return nil, err
}
@@ -640,8 +612,22 @@
return chain, nil
}
-func fetchNonce(client *http.Client, url string) (string, error) {
- resp, err := client.Head(url)
+// postJWS signs the body with the given key and POSTs it to the provided url.
+// The body argument must be JSON-serializable.
+func postJWS(ctx context.Context, client *http.Client, key crypto.Signer, url string, body interface{}) (*http.Response, error) {
+ nonce, err := fetchNonce(ctx, client, url)
+ if err != nil {
+ return nil, err
+ }
+ b, err := jwsEncodeJSON(body, key, nonce)
+ if err != nil {
+ return nil, err
+ }
+ return ctxhttp.Post(ctx, client, url, "application/jose+json", bytes.NewReader(b))
+}
+
+func fetchNonce(ctx context.Context, client *http.Client, url string) (string, error) {
+ resp, err := ctxhttp.Head(ctx, client, url)
if err != nil {
return "", nil
}
diff --git a/acme/internal/acme/acme_test.go b/acme/internal/acme/acme_test.go
index 8b088bf..37378c5 100644
--- a/acme/internal/acme/acme_test.go
+++ b/acme/internal/acme/acme_test.go
@@ -61,7 +61,7 @@
}))
defer ts.Close()
c := Client{DirectoryURL: ts.URL}
- dir, err := c.Discover()
+ dir, err := c.Discover(context.Background())
if err != nil {
t.Fatal(err)
}
@@ -130,7 +130,7 @@
c := Client{Key: testKey, dir: &Directory{RegURL: ts.URL}}
a := &Account{Contact: contacts}
var err error
- if a, err = c.Register(a, prompt); err != nil {
+ if a, err = c.Register(context.Background(), a, prompt); err != nil {
t.Fatal(err)
}
if a.URI != "https://ca.tld/acme/reg/1" {
@@ -194,7 +194,7 @@
c := Client{Key: testKey}
a := &Account{URI: ts.URL, Contact: contacts, AgreedTerms: terms}
var err error
- if a, err = c.UpdateReg(a); err != nil {
+ if a, err = c.UpdateReg(context.Background(), a); err != nil {
t.Fatal(err)
}
if a.Authz != "https://ca.tld/acme/new-authz" {
@@ -257,7 +257,7 @@
defer ts.Close()
c := Client{Key: testKey}
- a, err := c.GetReg(ts.URL)
+ a, err := c.GetReg(context.Background(), ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -329,7 +329,7 @@
defer ts.Close()
cl := Client{Key: testKey, dir: &Directory{AuthzURL: ts.URL}}
- auth, err := cl.Authorize("example.com")
+ auth, err := cl.Authorize(context.Background(), "example.com")
if err != nil {
t.Fatal(err)
}
@@ -408,7 +408,7 @@
defer ts.Close()
cl := Client{Key: testKey}
- auth, err := cl.GetAuthz(ts.URL)
+ auth, err := cl.GetAuthz(context.Background(), ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -471,7 +471,7 @@
defer ts.Close()
cl := Client{Key: testKey}
- chall, err := cl.GetChallenge(ts.URL)
+ chall, err := cl.GetChallenge(context.Background(), ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -532,7 +532,7 @@
defer ts.Close()
cl := Client{Key: testKey}
- c, err := cl.Accept(&Challenge{
+ c, err := cl.Accept(context.Background(), &Challenge{
URI: ts.URL,
Token: "token1",
Type: "http-01",
@@ -765,7 +765,7 @@
defer ts.Close()
for ; i < len(tests); i++ {
test := tests[i]
- n, err := fetchNonce(http.DefaultClient, ts.URL)
+ n, err := fetchNonce(context.Background(), http.DefaultClient, ts.URL)
if n != test.nonce {
t.Errorf("%d: n=%q; want %q", i, n, test.nonce)
}