acme: context-aware Client methods

This change adds a context to all exported methods of Client
which may perform network requests, to allow for easier control
over request timeouts and cancellation.

Change-Id: I635a4d7ad39a63ed9e6823b1af12fbb201c19647
Reviewed-on: https://go-review.googlesource.com/27091
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/acme/internal/acme/acme.go b/acme/internal/acme/acme.go
index 4255bae..ac9c5e0 100644
--- a/acme/internal/acme/acme.go
+++ b/acme/internal/acme/acme.go
@@ -78,7 +78,7 @@
 // It caches successful result. So, subsequent calls will not result in
 // a network round-trip. This also means mutating c.DirectoryURL after successful call
 // of this method will have no effect.
-func (c *Client) Discover() (Directory, error) {
+func (c *Client) Discover(ctx context.Context) (Directory, error) {
 	c.dirMu.Lock()
 	defer c.dirMu.Unlock()
 	if c.dir != nil {
@@ -89,7 +89,7 @@
 	if dirURL == "" {
 		dirURL = LetsEncryptURL
 	}
-	res, err := c.httpClient().Get(dirURL)
+	res, err := ctxhttp.Get(ctx, c.HTTPClient, dirURL)
 	if err != nil {
 		return Directory{}, err
 	}
@@ -136,7 +136,7 @@
 // CreateCert returns an error if the CA's response or chain was unreasonably large.
 // Callers are encouraged to parse the returned value to ensure the certificate is valid and has the expected features.
 func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration, bundle bool) (der [][]byte, certURL string, err error) {
-	if _, err := c.Discover(); err != nil {
+	if _, err := c.Discover(ctx); err != nil {
 		return nil, "", err
 	}
 
@@ -155,7 +155,7 @@
 		req.NotAfter = now.Add(exp).Format(time.RFC3339)
 	}
 
-	res, err := c.postJWS(c.dir.CertURL, req)
+	res, err := postJWS(ctx, c.HTTPClient, c.Key, c.dir.CertURL, req)
 	if err != nil {
 		return nil, "", err
 	}
@@ -171,7 +171,7 @@
 		return cert, curl, err
 	}
 	// slurp issued cert and CA chain, if requested
-	cert, err := responseCert(ctx, c.httpClient(), res, bundle)
+	cert, err := responseCert(ctx, c.HTTPClient, res, bundle)
 	return cert, curl, err
 }
 
@@ -186,13 +186,13 @@
 // and has expected features.
 func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) {
 	for {
-		res, err := ctxhttp.Get(ctx, c.httpClient(), url)
+		res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
 		if err != nil {
 			return nil, err
 		}
 		defer res.Body.Close()
 		if res.StatusCode == http.StatusOK {
-			return responseCert(ctx, c.httpClient(), res, bundle)
+			return responseCert(ctx, c.HTTPClient, res, bundle)
 		}
 		if res.StatusCode > 299 {
 			return nil, responseError(res)
@@ -221,13 +221,13 @@
 // If so, and the account has not indicated the acceptance of the terms (see Account for details),
 // Register calls prompt with a TOS URL provided by the CA. Prompt should report
 // whether the caller agrees to the terms. To always accept the terms, the caller can use AcceptTOS.
-func (c *Client) Register(a *Account, prompt func(tosURL string) bool) (*Account, error) {
-	if _, err := c.Discover(); err != nil {
+func (c *Client) Register(ctx context.Context, a *Account, prompt func(tosURL string) bool) (*Account, error) {
+	if _, err := c.Discover(ctx); err != nil {
 		return nil, err
 	}
 
 	var err error
-	if a, err = c.doReg(c.dir.RegURL, "new-reg", a); err != nil {
+	if a, err = c.doReg(ctx, c.dir.RegURL, "new-reg", a); err != nil {
 		return nil, err
 	}
 	var accept bool
@@ -236,15 +236,15 @@
 	}
 	if accept {
 		a.AgreedTerms = a.CurrentTerms
-		a, err = c.UpdateReg(a)
+		a, err = c.UpdateReg(ctx, a)
 	}
 	return a, err
 }
 
 // GetReg retrieves an existing registration.
 // The url argument is an Account URI.
-func (c *Client) GetReg(url string) (*Account, error) {
-	a, err := c.doReg(url, "reg", nil)
+func (c *Client) GetReg(ctx context.Context, url string) (*Account, error) {
+	a, err := c.doReg(ctx, url, "reg", nil)
 	if err != nil {
 		return nil, err
 	}
@@ -254,9 +254,9 @@
 
 // UpdateReg updates an existing registration.
 // It returns an updated account copy. The provided account is not modified.
-func (c *Client) UpdateReg(a *Account) (*Account, error) {
+func (c *Client) UpdateReg(ctx context.Context, a *Account) (*Account, error) {
 	uri := a.URI
-	a, err := c.doReg(uri, "reg", a)
+	a, err := c.doReg(ctx, uri, "reg", a)
 	if err != nil {
 		return nil, err
 	}
@@ -267,8 +267,8 @@
 // Authorize performs the initial step in an authorization flow.
 // The caller will then need to choose from and perform a set of returned
 // challenges using c.Accept in order to successfully complete authorization.
-func (c *Client) Authorize(domain string) (*Authorization, error) {
-	if _, err := c.Discover(); err != nil {
+func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization, error) {
+	if _, err := c.Discover(ctx); err != nil {
 		return nil, err
 	}
 
@@ -283,7 +283,7 @@
 		Resource:   "new-authz",
 		Identifier: authzID{Type: "dns", Value: domain},
 	}
-	res, err := c.postJWS(c.dir.AuthzURL, req)
+	res, err := postJWS(ctx, c.HTTPClient, c.Key, c.dir.AuthzURL, req)
 	if err != nil {
 		return nil, err
 	}
@@ -305,8 +305,8 @@
 // GetAuthz retrieves the current status of an authorization flow.
 //
 // A client typically polls an authz status using this method.
-func (c *Client) GetAuthz(url string) (*Authorization, error) {
-	res, err := c.httpClient().Get(url)
+func (c *Client) GetAuthz(ctx context.Context, url string) (*Authorization, error) {
+	res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
 	if err != nil {
 		return nil, err
 	}
@@ -324,8 +324,8 @@
 // GetChallenge retrieves the current status of an challenge.
 //
 // A client typically polls a challenge status using this method.
-func (c *Client) GetChallenge(url string) (*Challenge, error) {
-	res, err := c.httpClient().Get(url)
+func (c *Client) GetChallenge(ctx context.Context, url string) (*Challenge, error) {
+	res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
 	if err != nil {
 		return nil, err
 	}
@@ -344,7 +344,7 @@
 // previously obtained with c.Authorize.
 //
 // The server will then perform the validation asynchronously.
-func (c *Client) Accept(chal *Challenge) (*Challenge, error) {
+func (c *Client) Accept(ctx context.Context, chal *Challenge) (*Challenge, error) {
 	auth, err := keyAuth(c.Key.Public(), chal.Token)
 	if err != nil {
 		return nil, err
@@ -359,7 +359,7 @@
 		Type:     chal.Type,
 		Auth:     auth,
 	}
-	res, err := c.postJWS(chal.URI, req)
+	res, err := postJWS(ctx, c.HTTPClient, c.Key, chal.URI, req)
 	if err != nil {
 		return nil, err
 	}
@@ -452,31 +452,6 @@
 	return cert, sanA, nil
 }
 
-func (c *Client) httpClient() *http.Client {
-	if c.HTTPClient != nil {
-		return c.HTTPClient
-	}
-	return http.DefaultClient
-}
-
-// postJWS signs body and posts it to the provided url.
-// The body argument must be JSON-serializable.
-func (c *Client) postJWS(url string, body interface{}) (*http.Response, error) {
-	nonce, err := fetchNonce(c.httpClient(), url)
-	if err != nil {
-		return nil, err
-	}
-	b, err := jwsEncodeJSON(body, c.Key, nonce)
-	if err != nil {
-		return nil, err
-	}
-	req, err := http.NewRequest("POST", url, bytes.NewReader(b))
-	if err != nil {
-		return nil, err
-	}
-	return c.httpClient().Do(req)
-}
-
 // doReg sends all types of registration requests.
 // The type of request is identified by typ argument, which is a "resource"
 // in the ACME spec terms.
@@ -484,10 +459,7 @@
 // A non-nil acct argument indicates whether the intention is to mutate data
 // of the Account. Only Contact and Agreement of its fields are used
 // in such cases.
-//
-// The fields of acct will be populate with the server response
-// and may be overwritten.
-func (c *Client) doReg(url string, typ string, acct *Account) (*Account, error) {
+func (c *Client) doReg(ctx context.Context, url string, typ string, acct *Account) (*Account, error) {
 	req := struct {
 		Resource  string   `json:"resource"`
 		Contact   []string `json:"contact,omitempty"`
@@ -499,7 +471,7 @@
 		req.Contact = acct.Contact
 		req.Agreement = acct.AgreedTerms
 	}
-	res, err := c.postJWS(url, req)
+	res, err := postJWS(ctx, c.HTTPClient, c.Key, url, req)
 	if err != nil {
 		return nil, err
 	}
@@ -640,8 +612,22 @@
 	return chain, nil
 }
 
-func fetchNonce(client *http.Client, url string) (string, error) {
-	resp, err := client.Head(url)
+// postJWS signs the body with the given key and POSTs it to the provided url.
+// The body argument must be JSON-serializable.
+func postJWS(ctx context.Context, client *http.Client, key crypto.Signer, url string, body interface{}) (*http.Response, error) {
+	nonce, err := fetchNonce(ctx, client, url)
+	if err != nil {
+		return nil, err
+	}
+	b, err := jwsEncodeJSON(body, key, nonce)
+	if err != nil {
+		return nil, err
+	}
+	return ctxhttp.Post(ctx, client, url, "application/jose+json", bytes.NewReader(b))
+}
+
+func fetchNonce(ctx context.Context, client *http.Client, url string) (string, error) {
+	resp, err := ctxhttp.Head(ctx, client, url)
 	if err != nil {
 		return "", nil
 	}
diff --git a/acme/internal/acme/acme_test.go b/acme/internal/acme/acme_test.go
index 8b088bf..37378c5 100644
--- a/acme/internal/acme/acme_test.go
+++ b/acme/internal/acme/acme_test.go
@@ -61,7 +61,7 @@
 	}))
 	defer ts.Close()
 	c := Client{DirectoryURL: ts.URL}
-	dir, err := c.Discover()
+	dir, err := c.Discover(context.Background())
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -130,7 +130,7 @@
 	c := Client{Key: testKey, dir: &Directory{RegURL: ts.URL}}
 	a := &Account{Contact: contacts}
 	var err error
-	if a, err = c.Register(a, prompt); err != nil {
+	if a, err = c.Register(context.Background(), a, prompt); err != nil {
 		t.Fatal(err)
 	}
 	if a.URI != "https://ca.tld/acme/reg/1" {
@@ -194,7 +194,7 @@
 	c := Client{Key: testKey}
 	a := &Account{URI: ts.URL, Contact: contacts, AgreedTerms: terms}
 	var err error
-	if a, err = c.UpdateReg(a); err != nil {
+	if a, err = c.UpdateReg(context.Background(), a); err != nil {
 		t.Fatal(err)
 	}
 	if a.Authz != "https://ca.tld/acme/new-authz" {
@@ -257,7 +257,7 @@
 	defer ts.Close()
 
 	c := Client{Key: testKey}
-	a, err := c.GetReg(ts.URL)
+	a, err := c.GetReg(context.Background(), ts.URL)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -329,7 +329,7 @@
 	defer ts.Close()
 
 	cl := Client{Key: testKey, dir: &Directory{AuthzURL: ts.URL}}
-	auth, err := cl.Authorize("example.com")
+	auth, err := cl.Authorize(context.Background(), "example.com")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -408,7 +408,7 @@
 	defer ts.Close()
 
 	cl := Client{Key: testKey}
-	auth, err := cl.GetAuthz(ts.URL)
+	auth, err := cl.GetAuthz(context.Background(), ts.URL)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -471,7 +471,7 @@
 	defer ts.Close()
 
 	cl := Client{Key: testKey}
-	chall, err := cl.GetChallenge(ts.URL)
+	chall, err := cl.GetChallenge(context.Background(), ts.URL)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -532,7 +532,7 @@
 	defer ts.Close()
 
 	cl := Client{Key: testKey}
-	c, err := cl.Accept(&Challenge{
+	c, err := cl.Accept(context.Background(), &Challenge{
 		URI:   ts.URL,
 		Token: "token1",
 		Type:  "http-01",
@@ -765,7 +765,7 @@
 	defer ts.Close()
 	for ; i < len(tests); i++ {
 		test := tests[i]
-		n, err := fetchNonce(http.DefaultClient, ts.URL)
+		n, err := fetchNonce(context.Background(), http.DefaultClient, ts.URL)
 		if n != test.nonce {
 			t.Errorf("%d: n=%q; want %q", i, n, test.nonce)
 		}