acme: add Client.WaitAuthorization for polling
It is quite often the case users need to poll an authorization until it
reaches a final state.
This change adds a handy method to do exactly that,
as well as updating autocert.Manager to use the new method.
Change-Id: I4d25e4ee751731815b77980caab6da98d3440b5d
Reviewed-on: https://go-review.googlesource.com/27431
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/acme/autocert/autocert.go b/acme/autocert/autocert.go
index af2403d..2450cb6 100644
--- a/acme/autocert/autocert.go
+++ b/acme/autocert/autocert.go
@@ -453,29 +453,8 @@
return err
}
// wait for the CA to validate
- for {
- a, err := client.GetAuthz(ctx, authz.URI)
- if err == nil {
- if a.Status == acme.StatusValid {
- break
- }
- if a.Status == acme.StatusInvalid {
- return fmt.Errorf("acme/autocert: validation for domain %q failed", domain)
- }
- }
- // still pending
- d := time.Second
- if ae, ok := err.(*acme.Error); ok {
- d = retryAfter(ae.Header.Get("retry-after"))
- }
- select {
- case <-ctx.Done():
- return ctx.Err()
- case <-time.After(d):
- // retry
- }
- }
- return nil
+ _, err = client.WaitAuthorization(ctx, authz.URI)
+ return err
}
// certState returns existing state or creates a new one locked for read/write.
diff --git a/acme/internal/acme/acme.go b/acme/internal/acme/acme.go
index cee8f3e..9eb35a3 100644
--- a/acme/internal/acme/acme.go
+++ b/acme/internal/acme/acme.go
@@ -198,10 +198,7 @@
if res.StatusCode > 299 {
return nil, responseError(res)
}
- d, err := retryAfter(res.Header.Get("retry-after"))
- if err != nil {
- d = 3 * time.Second
- }
+ d := retryAfter(res.Header.Get("retry-after"), 3*time.Second)
select {
case <-time.After(d):
// retry
@@ -341,10 +338,11 @@
return v.authorization(res.Header.Get("Location")), nil
}
-// GetAuthz retrieves the current status of an authorization flow.
+// GetAuthorization retrieves an authorization identified by the given URL.
//
-// A client typically polls an authz status using this method.
-func (c *Client) GetAuthz(ctx context.Context, url string) (*Authorization, error) {
+// If a caller needs to poll an authorization until its status is final,
+// see the WaitAuthorization method.
+func (c *Client) GetAuthorization(ctx context.Context, url string) (*Authorization, error) {
res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
if err != nil {
return nil, err
@@ -360,6 +358,63 @@
return v.authorization(url), nil
}
+// WaitAuthorization polls an authorization at the given URL
+// until it is in one of the final states, StatusValid or StatusInvalid,
+// or the context is done.
+//
+// It returns a non-nil Authorization only if its Status is StatusValid.
+// In all other cases WaitAuthorization returns an error.
+// If the Status is StatusInvalid, the returned error is ErrAuthorizationFailed.
+func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorization, error) {
+ var count int
+ sleep := func(v string, inc int) error {
+ count += inc
+ d := backoff(count, 10*time.Second)
+ d = retryAfter(v, d)
+ wakeup := time.NewTimer(d)
+ defer wakeup.Stop()
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-wakeup.C:
+ return nil
+ }
+ }
+
+ for {
+ res, err := ctxhttp.Get(ctx, c.HTTPClient, url)
+ if err != nil {
+ return nil, err
+ }
+ retry := res.Header.Get("retry-after")
+ if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusAccepted {
+ res.Body.Close()
+ if err := sleep(retry, 1); err != nil {
+ return nil, err
+ }
+ continue
+ }
+ var raw wireAuthz
+ err = json.NewDecoder(res.Body).Decode(&raw)
+ res.Body.Close()
+ if err != nil {
+ if err := sleep(retry, 0); err != nil {
+ return nil, err
+ }
+ continue
+ }
+ if raw.Status == StatusValid {
+ return raw.authorization(url), nil
+ }
+ if raw.Status == StatusInvalid {
+ return nil, ErrAuthorizationFailed
+ }
+ if err := sleep(retry, 0); err != nil {
+ return nil, err
+ }
+ }
+}
+
// GetChallenge retrieves the current status of an challenge.
//
// A client typically polls a challenge status using this method.
@@ -699,15 +754,41 @@
return links
}
-func retryAfter(v string) (time.Duration, error) {
+// retryAfter parses a Retry-After HTTP header value,
+// trying to convert v into an int (seconds) or use http.ParseTime otherwise.
+// It returns d if v cannot be parsed.
+func retryAfter(v string, d time.Duration) time.Duration {
if i, err := strconv.Atoi(v); err == nil {
- return time.Duration(i) * time.Second, nil
+ return time.Duration(i) * time.Second
}
t, err := http.ParseTime(v)
if err != nil {
- return 0, err
+ return d
}
- return t.Sub(timeNow()), nil
+ return t.Sub(timeNow())
+}
+
+// backoff computes a duration after which an n+1 retry iteration should occur
+// using truncated exponential backoff algorithm.
+//
+// The n argument is always bounded between 0 and 30.
+// The max argument defines upper bound for the returned value.
+func backoff(n int, max time.Duration) time.Duration {
+ if n < 0 {
+ n = 0
+ }
+ if n > 30 {
+ n = 30
+ }
+ var d time.Duration
+ if x, err := rand.Int(rand.Reader, big.NewInt(1000)); err == nil {
+ d = time.Duration(x.Int64()) * time.Millisecond
+ }
+ d += time.Duration(1<<uint(n)) * time.Second
+ if d > max {
+ return max
+ }
+ return d
}
// keyAuth generates a key authorization string for a given token.
diff --git a/acme/internal/acme/acme_test.go b/acme/internal/acme/acme_test.go
index fc54079..364cd68 100644
--- a/acme/internal/acme/acme_test.go
+++ b/acme/internal/acme/acme_test.go
@@ -385,7 +385,7 @@
}
}
-func TestPollAuthz(t *testing.T) {
+func TestGetAuthorization(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
t.Errorf("r.Method = %q; want GET", r.Method)
@@ -414,7 +414,7 @@
defer ts.Close()
cl := Client{Key: testKeyEC}
- auth, err := cl.GetAuthz(context.Background(), ts.URL)
+ auth, err := cl.GetAuthorization(context.Background(), ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -461,6 +461,95 @@
}
}
+func TestWaitAuthorization(t *testing.T) {
+ var count int
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ count++
+ w.Header().Set("retry-after", "0")
+ if count > 1 {
+ fmt.Fprintf(w, `{"status":"valid"}`)
+ return
+ }
+ fmt.Fprintf(w, `{"status":"pending"}`)
+ }))
+ defer ts.Close()
+
+ type res struct {
+ authz *Authorization
+ err error
+ }
+ done := make(chan res)
+ defer close(done)
+ go func() {
+ var client Client
+ a, err := client.WaitAuthorization(context.Background(), ts.URL)
+ done <- res{a, err}
+ }()
+
+ select {
+ case <-time.After(5 * time.Second):
+ t.Fatal("WaitAuthz took too long to return")
+ case res := <-done:
+ if res.err != nil {
+ t.Fatalf("res.err = %v", res.err)
+ }
+ if res.authz == nil {
+ t.Fatal("res.authz is nil")
+ }
+ }
+}
+
+func TestWaitAuthorizationInvalid(t *testing.T) {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintf(w, `{"status":"invalid"}`)
+ }))
+ defer ts.Close()
+
+ res := make(chan error)
+ defer close(res)
+ go func() {
+ var client Client
+ _, err := client.WaitAuthorization(context.Background(), ts.URL)
+ res <- err
+ }()
+
+ select {
+ case <-time.After(3 * time.Second):
+ t.Fatal("WaitAuthz took too long to return")
+ case err := <-res:
+ if err == nil {
+ t.Error("err is nil")
+ }
+ }
+}
+
+func TestWaitAuthorizationCancel(t *testing.T) {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("retry-after", "60")
+ fmt.Fprintf(w, `{"status":"pending"}`)
+ }))
+ defer ts.Close()
+
+ res := make(chan error)
+ defer close(res)
+ go func() {
+ var client Client
+ ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
+ defer cancel()
+ _, err := client.WaitAuthorization(ctx, ts.URL)
+ res <- err
+ }()
+
+ select {
+ case <-time.After(time.Second):
+ t.Fatal("WaitAuthz took too long to return")
+ case err := <-res:
+ if err == nil {
+ t.Error("err is nil")
+ }
+ }
+}
+
func TestPollChallenge(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
@@ -954,3 +1043,28 @@
t.Errorf("path = %q; want %q", path, urlpath)
}
}
+
+func TestBackoff(t *testing.T) {
+ tt := []struct{ min, max time.Duration }{
+ {time.Second, 2 * time.Second},
+ {2 * time.Second, 3 * time.Second},
+ {4 * time.Second, 5 * time.Second},
+ {8 * time.Second, 9 * time.Second},
+ }
+ for i, test := range tt {
+ d := backoff(i, time.Minute)
+ if d < test.min || test.max < d {
+ t.Errorf("%d: d = %v; want between %v and %v", i, d, test.min, test.max)
+ }
+ }
+
+ min, max := time.Second, 2*time.Second
+ if d := backoff(-1, time.Minute); d < min || max < d {
+ t.Errorf("d = %v; want between %v and %v", d, min, max)
+ }
+
+ bound := 10 * time.Second
+ if d := backoff(100, bound); d != bound {
+ t.Errorf("d = %v; want %v", d, bound)
+ }
+}
diff --git a/acme/internal/acme/types.go b/acme/internal/acme/types.go
index 612cd01..0513b2e 100644
--- a/acme/internal/acme/types.go
+++ b/acme/internal/acme/types.go
@@ -33,8 +33,14 @@
CRLReasonAACompromise CRLReasonCode = 10
)
-// ErrUnsupportedKey is returned when an unsupported key type is encountered.
-var ErrUnsupportedKey = errors.New("acme: unknown key type; only RSA and ECDSA are supported")
+var (
+ // ErrAuthorizationFailed indicates that an authorization for an identifier
+ // did not succeed.
+ ErrAuthorizationFailed = errors.New("acme: identifier authorization failed")
+
+ // ErrUnsupportedKey is returned when an unsupported key type is encountered.
+ ErrUnsupportedKey = errors.New("acme: unknown key type; only RSA and ECDSA are supported")
+)
// Error is an ACME error, defined in Problem Details for HTTP APIs doc
// http://tools.ietf.org/html/draft-ietf-appsawg-http-problem.