acme: stop using ctxhttp

The ctxhttp package used to be big and gross before net/http supported
contexts natively. Nowadays it barely does anything. Stop using it,
because it just pulls in the old context package anyway. (We can't
really clean up the ctxhttp package until Go 1.9)

Change-Id: I48b11f2f483783a32cbaa75e244301148a304c08
Reviewed-on: https://go-review.googlesource.com/40110
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Alex Vaghin <ddos@google.com>
diff --git a/acme/acme.go b/acme/acme.go
index 8421819..140d422 100644
--- a/acme/acme.go
+++ b/acme/acme.go
@@ -37,8 +37,6 @@
 	"strings"
 	"sync"
 	"time"
-
-	"golang.org/x/net/context/ctxhttp"
 )
 
 // LetsEncryptURL is the Directory endpoint of Let's Encrypt CA.
@@ -133,7 +131,7 @@
 	if dirURL == "" {
 		dirURL = LetsEncryptURL
 	}
-	res, err := ctxhttp.Get(ctx, c.HTTPClient, dirURL)
+	res, err := c.get(ctx, dirURL)
 	if err != nil {
 		return Directory{}, err
 	}
@@ -216,7 +214,7 @@
 		return cert, curl, err
 	}
 	// slurp issued cert and CA chain, if requested
-	cert, err := responseCert(ctx, c.HTTPClient, res, bundle)
+	cert, err := c.responseCert(ctx, res, bundle)
 	return cert, curl, err
 }
 
@@ -231,13 +229,13 @@
 // and has expected features.
 func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) {
 	for {
-		res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
+		res, err := c.get(ctx, url)
 		if err != nil {
 			return nil, err
 		}
 		defer res.Body.Close()
 		if res.StatusCode == http.StatusOK {
-			return responseCert(ctx, c.HTTPClient, res, bundle)
+			return c.responseCert(ctx, res, bundle)
 		}
 		if res.StatusCode > 299 {
 			return nil, responseError(res)
@@ -387,7 +385,7 @@
 // If a caller needs to poll an authorization until its status is final,
 // see the WaitAuthorization method.
 func (c *Client) GetAuthorization(ctx context.Context, url string) (*Authorization, error) {
-	res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
+	res, err := c.get(ctx, url)
 	if err != nil {
 		return nil, err
 	}
@@ -456,7 +454,7 @@
 	}
 
 	for {
-		res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
+		res, err := c.get(ctx, url)
 		if err != nil {
 			return nil, err
 		}
@@ -493,7 +491,7 @@
 //
 // A client typically polls a challenge status using this method.
 func (c *Client) GetChallenge(ctx context.Context, url string) (*Challenge, error) {
-	res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
+	res, err := c.get(ctx, url)
 	if err != nil {
 		return nil, err
 	}
@@ -708,7 +706,7 @@
 	if err != nil {
 		return nil, err
 	}
-	res, err := ctxhttp.Post(ctx, c.HTTPClient, url, "application/jose+json", bytes.NewReader(b))
+	res, err := c.post(ctx, url, "application/jose+json", bytes.NewReader(b))
 	if err != nil {
 		return nil, err
 	}
@@ -722,7 +720,7 @@
 	c.noncesMu.Lock()
 	defer c.noncesMu.Unlock()
 	if len(c.nonces) == 0 {
-		return fetchNonce(ctx, c.HTTPClient, url)
+		return c.fetchNonce(ctx, url)
 	}
 	var nonce string
 	for nonce = range c.nonces {
@@ -749,8 +747,58 @@
 	c.nonces[v] = struct{}{}
 }
 
-func fetchNonce(ctx context.Context, client *http.Client, url string) (string, error) {
-	resp, err := ctxhttp.Head(ctx, client, url)
+func (c *Client) httpClient() *http.Client {
+	if c.HTTPClient != nil {
+		return c.HTTPClient
+	}
+	return http.DefaultClient
+}
+
+func (c *Client) get(ctx context.Context, urlStr string) (*http.Response, error) {
+	req, err := http.NewRequest("GET", urlStr, nil)
+	if err != nil {
+		return nil, err
+	}
+	return c.do(ctx, req)
+}
+
+func (c *Client) head(ctx context.Context, urlStr string) (*http.Response, error) {
+	req, err := http.NewRequest("HEAD", urlStr, nil)
+	if err != nil {
+		return nil, err
+	}
+	return c.do(ctx, req)
+}
+
+func (c *Client) post(ctx context.Context, urlStr, contentType string, body io.Reader) (*http.Response, error) {
+	req, err := http.NewRequest("POST", urlStr, body)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Content-Type", contentType)
+	return c.do(ctx, req)
+}
+
+func (c *Client) do(ctx context.Context, req *http.Request) (*http.Response, error) {
+	res, err := c.httpClient().Do(req.WithContext(ctx))
+	if err != nil {
+		select {
+		case <-ctx.Done():
+			// Prefer the unadorned context error.
+			// (The acme package had tests assuming this, previously from ctxhttp's
+			// behavior, predating net/http supporting contexts natively)
+			// TODO(bradfitz): reconsider this in the future. But for now this
+			// requires no test updates.
+			return nil, ctx.Err()
+		default:
+			return nil, err
+		}
+	}
+	return res, nil
+}
+
+func (c *Client) fetchNonce(ctx context.Context, url string) (string, error) {
+	resp, err := c.head(ctx, url)
 	if err != nil {
 		return "", err
 	}
@@ -769,7 +817,7 @@
 	return h.Get("Replay-Nonce")
 }
 
-func responseCert(ctx context.Context, client *http.Client, res *http.Response, bundle bool) ([][]byte, error) {
+func (c *Client) responseCert(ctx context.Context, res *http.Response, bundle bool) ([][]byte, error) {
 	b, err := ioutil.ReadAll(io.LimitReader(res.Body, maxCertSize+1))
 	if err != nil {
 		return nil, fmt.Errorf("acme: response stream: %v", err)
@@ -793,7 +841,7 @@
 		return nil, errors.New("acme: rel=up link is too large")
 	}
 	for _, url := range up {
-		cc, err := chainCert(ctx, client, url, 0)
+		cc, err := c.chainCert(ctx, url, 0)
 		if err != nil {
 			return nil, err
 		}
@@ -836,12 +884,12 @@
 // if the recursion level reaches maxChainLen.
 //
 // First chainCert call starts with depth of 0.
-func chainCert(ctx context.Context, client *http.Client, url string, depth int) ([][]byte, error) {
+func (c *Client) chainCert(ctx context.Context, url string, depth int) ([][]byte, error) {
 	if depth >= maxChainLen {
 		return nil, errors.New("acme: certificate chain is too deep")
 	}
 
-	res, err := ctxhttp.Get(ctx, client, url)
+	res, err := c.get(ctx, url)
 	if err != nil {
 		return nil, err
 	}
@@ -863,7 +911,7 @@
 		return nil, errors.New("acme: certificate chain is too large")
 	}
 	for _, up := range uplink {
-		cc, err := chainCert(ctx, client, up, depth+1)
+		cc, err := c.chainCert(ctx, up, depth+1)
 		if err != nil {
 			return nil, err
 		}
diff --git a/acme/acme_test.go b/acme/acme_test.go
index e746c07..b91533d 100644
--- a/acme/acme_test.go
+++ b/acme/acme_test.go
@@ -980,7 +980,8 @@
 	defer ts.Close()
 	for ; i < len(tests); i++ {
 		test := tests[i]
-		n, err := fetchNonce(context.Background(), http.DefaultClient, ts.URL)
+		c := &Client{}
+		n, err := c.fetchNonce(context.Background(), ts.URL)
 		if n != test.nonce {
 			t.Errorf("%d: n=%q; want %q", i, n, test.nonce)
 		}
@@ -998,7 +999,8 @@
 		w.WriteHeader(http.StatusTooManyRequests)
 	}))
 	defer ts.Close()
-	_, err := fetchNonce(context.Background(), http.DefaultClient, ts.URL)
+	c := &Client{}
+	_, err := c.fetchNonce(context.Background(), ts.URL)
 	e, ok := err.(*Error)
 	if !ok {
 		t.Fatalf("err is %T; want *Error", err)