acme: add Client.WaitAuthorization for polling

It is quite often the case users need to poll an authorization until it
reaches a final state.

This change adds a handy method to do exactly that,
as well as updating autocert.Manager to use the new method.

Change-Id: I4d25e4ee751731815b77980caab6da98d3440b5d
Reviewed-on: https://go-review.googlesource.com/27431
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/acme/autocert/autocert.go b/acme/autocert/autocert.go
index af2403d..2450cb6 100644
--- a/acme/autocert/autocert.go
+++ b/acme/autocert/autocert.go
@@ -453,29 +453,8 @@
 		return err
 	}
 	// wait for the CA to validate
-	for {
-		a, err := client.GetAuthz(ctx, authz.URI)
-		if err == nil {
-			if a.Status == acme.StatusValid {
-				break
-			}
-			if a.Status == acme.StatusInvalid {
-				return fmt.Errorf("acme/autocert: validation for domain %q failed", domain)
-			}
-		}
-		// still pending
-		d := time.Second
-		if ae, ok := err.(*acme.Error); ok {
-			d = retryAfter(ae.Header.Get("retry-after"))
-		}
-		select {
-		case <-ctx.Done():
-			return ctx.Err()
-		case <-time.After(d):
-			// retry
-		}
-	}
-	return nil
+	_, err = client.WaitAuthorization(ctx, authz.URI)
+	return err
 }
 
 // certState returns existing state or creates a new one locked for read/write.
diff --git a/acme/internal/acme/acme.go b/acme/internal/acme/acme.go
index cee8f3e..9eb35a3 100644
--- a/acme/internal/acme/acme.go
+++ b/acme/internal/acme/acme.go
@@ -198,10 +198,7 @@
 		if res.StatusCode > 299 {
 			return nil, responseError(res)
 		}
-		d, err := retryAfter(res.Header.Get("retry-after"))
-		if err != nil {
-			d = 3 * time.Second
-		}
+		d := retryAfter(res.Header.Get("retry-after"), 3*time.Second)
 		select {
 		case <-time.After(d):
 			// retry
@@ -341,10 +338,11 @@
 	return v.authorization(res.Header.Get("Location")), nil
 }
 
-// GetAuthz retrieves the current status of an authorization flow.
+// GetAuthorization retrieves an authorization identified by the given URL.
 //
-// A client typically polls an authz status using this method.
-func (c *Client) GetAuthz(ctx context.Context, url string) (*Authorization, error) {
+// If a caller needs to poll an authorization until its status is final,
+// see the WaitAuthorization method.
+func (c *Client) GetAuthorization(ctx context.Context, url string) (*Authorization, error) {
 	res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
 	if err != nil {
 		return nil, err
@@ -360,6 +358,63 @@
 	return v.authorization(url), nil
 }
 
+// WaitAuthorization polls an authorization at the given URL
+// until it is in one of the final states, StatusValid or StatusInvalid,
+// or the context is done.
+//
+// It returns a non-nil Authorization only if its Status is StatusValid.
+// In all other cases WaitAuthorization returns an error.
+// If the Status is StatusInvalid, the returned error is ErrAuthorizationFailed.
+func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorization, error) {
+	var count int
+	sleep := func(v string, inc int) error {
+		count += inc
+		d := backoff(count, 10*time.Second)
+		d = retryAfter(v, d)
+		wakeup := time.NewTimer(d)
+		defer wakeup.Stop()
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-wakeup.C:
+			return nil
+		}
+	}
+
+	for {
+		res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
+		if err != nil {
+			return nil, err
+		}
+		retry := res.Header.Get("retry-after")
+		if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusAccepted {
+			res.Body.Close()
+			if err := sleep(retry, 1); err != nil {
+				return nil, err
+			}
+			continue
+		}
+		var raw wireAuthz
+		err = json.NewDecoder(res.Body).Decode(&raw)
+		res.Body.Close()
+		if err != nil {
+			if err := sleep(retry, 0); err != nil {
+				return nil, err
+			}
+			continue
+		}
+		if raw.Status == StatusValid {
+			return raw.authorization(url), nil
+		}
+		if raw.Status == StatusInvalid {
+			return nil, ErrAuthorizationFailed
+		}
+		if err := sleep(retry, 0); err != nil {
+			return nil, err
+		}
+	}
+}
+
 // GetChallenge retrieves the current status of an challenge.
 //
 // A client typically polls a challenge status using this method.
@@ -699,15 +754,41 @@
 	return links
 }
 
-func retryAfter(v string) (time.Duration, error) {
+// retryAfter parses a Retry-After HTTP header value,
+// trying to convert v into an int (seconds) or use http.ParseTime otherwise.
+// It returns d if v cannot be parsed.
+func retryAfter(v string, d time.Duration) time.Duration {
 	if i, err := strconv.Atoi(v); err == nil {
-		return time.Duration(i) * time.Second, nil
+		return time.Duration(i) * time.Second
 	}
 	t, err := http.ParseTime(v)
 	if err != nil {
-		return 0, err
+		return d
 	}
-	return t.Sub(timeNow()), nil
+	return t.Sub(timeNow())
+}
+
+// backoff computes a duration after which an n+1 retry iteration should occur
+// using truncated exponential backoff algorithm.
+//
+// The n argument is always bounded between 0 and 30.
+// The max argument defines upper bound for the returned value.
+func backoff(n int, max time.Duration) time.Duration {
+	if n < 0 {
+		n = 0
+	}
+	if n > 30 {
+		n = 30
+	}
+	var d time.Duration
+	if x, err := rand.Int(rand.Reader, big.NewInt(1000)); err == nil {
+		d = time.Duration(x.Int64()) * time.Millisecond
+	}
+	d += time.Duration(1<<uint(n)) * time.Second
+	if d > max {
+		return max
+	}
+	return d
 }
 
 // keyAuth generates a key authorization string for a given token.
diff --git a/acme/internal/acme/acme_test.go b/acme/internal/acme/acme_test.go
index fc54079..364cd68 100644
--- a/acme/internal/acme/acme_test.go
+++ b/acme/internal/acme/acme_test.go
@@ -385,7 +385,7 @@
 	}
 }
 
-func TestPollAuthz(t *testing.T) {
+func TestGetAuthorization(t *testing.T) {
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		if r.Method != "GET" {
 			t.Errorf("r.Method = %q; want GET", r.Method)
@@ -414,7 +414,7 @@
 	defer ts.Close()
 
 	cl := Client{Key: testKeyEC}
-	auth, err := cl.GetAuthz(context.Background(), ts.URL)
+	auth, err := cl.GetAuthorization(context.Background(), ts.URL)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -461,6 +461,95 @@
 	}
 }
 
+func TestWaitAuthorization(t *testing.T) {
+	var count int
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		count++
+		w.Header().Set("retry-after", "0")
+		if count > 1 {
+			fmt.Fprintf(w, `{"status":"valid"}`)
+			return
+		}
+		fmt.Fprintf(w, `{"status":"pending"}`)
+	}))
+	defer ts.Close()
+
+	type res struct {
+		authz *Authorization
+		err   error
+	}
+	done := make(chan res)
+	defer close(done)
+	go func() {
+		var client Client
+		a, err := client.WaitAuthorization(context.Background(), ts.URL)
+		done <- res{a, err}
+	}()
+
+	select {
+	case <-time.After(5 * time.Second):
+		t.Fatal("WaitAuthz took too long to return")
+	case res := <-done:
+		if res.err != nil {
+			t.Fatalf("res.err =  %v", res.err)
+		}
+		if res.authz == nil {
+			t.Fatal("res.authz is nil")
+		}
+	}
+}
+
+func TestWaitAuthorizationInvalid(t *testing.T) {
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		fmt.Fprintf(w, `{"status":"invalid"}`)
+	}))
+	defer ts.Close()
+
+	res := make(chan error)
+	defer close(res)
+	go func() {
+		var client Client
+		_, err := client.WaitAuthorization(context.Background(), ts.URL)
+		res <- err
+	}()
+
+	select {
+	case <-time.After(3 * time.Second):
+		t.Fatal("WaitAuthz took too long to return")
+	case err := <-res:
+		if err == nil {
+			t.Error("err is nil")
+		}
+	}
+}
+
+func TestWaitAuthorizationCancel(t *testing.T) {
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("retry-after", "60")
+		fmt.Fprintf(w, `{"status":"pending"}`)
+	}))
+	defer ts.Close()
+
+	res := make(chan error)
+	defer close(res)
+	go func() {
+		var client Client
+		ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
+		defer cancel()
+		_, err := client.WaitAuthorization(ctx, ts.URL)
+		res <- err
+	}()
+
+	select {
+	case <-time.After(time.Second):
+		t.Fatal("WaitAuthz took too long to return")
+	case err := <-res:
+		if err == nil {
+			t.Error("err is nil")
+		}
+	}
+}
+
 func TestPollChallenge(t *testing.T) {
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		if r.Method != "GET" {
@@ -954,3 +1043,28 @@
 		t.Errorf("path = %q; want %q", path, urlpath)
 	}
 }
+
+func TestBackoff(t *testing.T) {
+	tt := []struct{ min, max time.Duration }{
+		{time.Second, 2 * time.Second},
+		{2 * time.Second, 3 * time.Second},
+		{4 * time.Second, 5 * time.Second},
+		{8 * time.Second, 9 * time.Second},
+	}
+	for i, test := range tt {
+		d := backoff(i, time.Minute)
+		if d < test.min || test.max < d {
+			t.Errorf("%d: d = %v; want between %v and %v", i, d, test.min, test.max)
+		}
+	}
+
+	min, max := time.Second, 2*time.Second
+	if d := backoff(-1, time.Minute); d < min || max < d {
+		t.Errorf("d = %v; want between %v and %v", d, min, max)
+	}
+
+	bound := 10 * time.Second
+	if d := backoff(100, bound); d != bound {
+		t.Errorf("d = %v; want %v", d, bound)
+	}
+}
diff --git a/acme/internal/acme/types.go b/acme/internal/acme/types.go
index 612cd01..0513b2e 100644
--- a/acme/internal/acme/types.go
+++ b/acme/internal/acme/types.go
@@ -33,8 +33,14 @@
 	CRLReasonAACompromise         CRLReasonCode = 10
 )
 
-// ErrUnsupportedKey is returned when an unsupported key type is encountered.
-var ErrUnsupportedKey = errors.New("acme: unknown key type; only RSA and ECDSA are supported")
+var (
+	// ErrAuthorizationFailed indicates that an authorization for an identifier
+	// did not succeed.
+	ErrAuthorizationFailed = errors.New("acme: identifier authorization failed")
+
+	// ErrUnsupportedKey is returned when an unsupported key type is encountered.
+	ErrUnsupportedKey = errors.New("acme: unknown key type; only RSA and ECDSA are supported")
+)
 
 // Error is an ACME error, defined in Problem Details for HTTP APIs doc
 // http://tools.ietf.org/html/draft-ietf-appsawg-http-problem.