http2: retry requests after receiving REFUSED STREAM

RoundTrip will retry a request if it receives REFUSED_STREAM. To guard
against servers that use REFUSED_STREAM to encourage rate limiting, or
servers that return REFUSED_STREAM deterministically for some requests,
we retry after an exponential backoff and we cap the number of retries.

The exponential backoff starts on the second retry, with a backoff
sequence of 1s, 2s, 4s, etc, with 10% random jitter.

The retry cap was set to 6, somewhat arbitrarily. Rationale: this is
what Firefox does.

Updates golang/go#20985

Change-Id: I4dcac4392ac4a3220d6d839f28bf943fe6b3fea7
Reviewed-on: https://go-review.googlesource.com/50471
Run-TryBot: Tom Bergan <tombergan@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 850d7ae..c98cfbc 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -18,6 +18,7 @@
 	"io/ioutil"
 	"log"
 	"math"
+	mathrand "math/rand"
 	"net"
 	"net/http"
 	"sort"
@@ -329,7 +330,7 @@
 	}
 
 	addr := authorityAddr(req.URL.Scheme, req.URL.Host)
-	for {
+	for retry := 0; ; retry++ {
 		cc, err := t.connPool().GetClientConn(req, addr)
 		if err != nil {
 			t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err)
@@ -337,9 +338,25 @@
 		}
 		traceGotConn(req, cc)
 		res, err := cc.RoundTrip(req)
-		if err != nil {
-			if req, err = shouldRetryRequest(req, err); err == nil {
-				continue
+		if err != nil && retry <= 6 {
+			afterBodyWrite := false
+			if e, ok := err.(afterReqBodyWriteError); ok {
+				err = e
+				afterBodyWrite = true
+			}
+			if req, err = shouldRetryRequest(req, err, afterBodyWrite); err == nil {
+				// After the first retry, do exponential backoff with 10% jitter.
+				if retry == 0 {
+					continue
+				}
+				backoff := float64(uint(1) << (uint(retry) - 1))
+				backoff += backoff * (0.1 * mathrand.Float64())
+				select {
+				case <-time.After(time.Second * time.Duration(backoff)):
+					continue
+				case <-req.Context().Done():
+					return nil, req.Context().Err()
+				}
 			}
 		}
 		if err != nil {
@@ -360,43 +377,60 @@
 }
 
 var (
-	errClientConnClosed   = errors.New("http2: client conn is closed")
-	errClientConnUnusable = errors.New("http2: client conn not usable")
-
-	errClientConnGotGoAway                 = errors.New("http2: Transport received Server's graceful shutdown GOAWAY")
-	errClientConnGotGoAwayAfterSomeReqBody = errors.New("http2: Transport received Server's graceful shutdown GOAWAY; some request body already written")
+	errClientConnClosed    = errors.New("http2: client conn is closed")
+	errClientConnUnusable  = errors.New("http2: client conn not usable")
+	errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY")
 )
 
+// afterReqBodyWriteError is a wrapper around errors returned by ClientConn.RoundTrip.
+// It is used to signal that err happened after part of Request.Body was sent to the server.
+type afterReqBodyWriteError struct {
+	err error
+}
+
+func (e afterReqBodyWriteError) Error() string {
+	return e.err.Error() + "; some request body already written"
+}
+
 // shouldRetryRequest is called by RoundTrip when a request fails to get
 // response headers. It is always called with a non-nil error.
 // It returns either a request to retry (either the same request, or a
 // modified clone), or an error if the request can't be replayed.
-func shouldRetryRequest(req *http.Request, err error) (*http.Request, error) {
-	switch err {
-	default:
+func shouldRetryRequest(req *http.Request, err error, afterBodyWrite bool) (*http.Request, error) {
+	if !canRetryError(err) {
 		return nil, err
-	case errClientConnUnusable, errClientConnGotGoAway:
-		return req, nil
-	case errClientConnGotGoAwayAfterSomeReqBody:
-		// If the Body is nil (or http.NoBody), it's safe to reuse
-		// this request and its Body.
-		if req.Body == nil || reqBodyIsNoBody(req.Body) {
-			return req, nil
-		}
-		// Otherwise we depend on the Request having its GetBody
-		// func defined.
-		getBody := reqGetBody(req) // Go 1.8: getBody = req.GetBody
-		if getBody == nil {
-			return nil, errors.New("http2: Transport: peer server initiated graceful shutdown after some of Request.Body was written; define Request.GetBody to avoid this error")
-		}
-		body, err := getBody()
-		if err != nil {
-			return nil, err
-		}
-		newReq := *req
-		newReq.Body = body
-		return &newReq, nil
 	}
+	if !afterBodyWrite {
+		return req, nil
+	}
+	// If the Body is nil (or http.NoBody), it's safe to reuse
+	// this request and its Body.
+	if req.Body == nil || reqBodyIsNoBody(req.Body) {
+		return req, nil
+	}
+	// Otherwise we depend on the Request having its GetBody
+	// func defined.
+	getBody := reqGetBody(req) // Go 1.8: getBody = req.GetBody
+	if getBody == nil {
+		return nil, fmt.Errorf("http2: Transport: cannot retry err [%v] after Request.Body was written; define Request.GetBody to avoid this error", err)
+	}
+	body, err := getBody()
+	if err != nil {
+		return nil, err
+	}
+	newReq := *req
+	newReq.Body = body
+	return &newReq, nil
+}
+
+func canRetryError(err error) bool {
+	if err == errClientConnUnusable || err == errClientConnGotGoAway {
+		return true
+	}
+	if se, ok := err.(StreamError); ok {
+		return se.Code == ErrCodeRefusedStream
+	}
+	return false
 }
 
 func (t *Transport) dialClientConn(addr string, singleUse bool) (*ClientConn, error) {
@@ -816,14 +850,13 @@
 			cs.abortRequestBodyWrite(errStopReqBodyWrite)
 		}
 		if re.err != nil {
-			if re.err == errClientConnGotGoAway {
-				cc.mu.Lock()
-				if cs.startedWrite {
-					re.err = errClientConnGotGoAwayAfterSomeReqBody
-				}
-				cc.mu.Unlock()
-			}
+			cc.mu.Lock()
+			afterBodyWrite := cs.startedWrite
+			cc.mu.Unlock()
 			cc.forgetStreamID(cs.ID)
+			if afterBodyWrite {
+				return nil, afterReqBodyWriteError{re.err}
+			}
 			return nil, re.err
 		}
 		res.Request = req
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 15dfa07..0e7b801 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -2926,6 +2926,116 @@
 	}
 }
 
+func TestTransportRetryAfterRefusedStream(t *testing.T) {
+	clientDone := make(chan struct{})
+	ct := newClientTester(t)
+	ct.client = func() error {
+		defer ct.cc.(*net.TCPConn).CloseWrite()
+		defer close(clientDone)
+		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+		resp, err := ct.tr.RoundTrip(req)
+		if err != nil {
+			return fmt.Errorf("RoundTrip: %v", err)
+		}
+		resp.Body.Close()
+		if resp.StatusCode != 204 {
+			return fmt.Errorf("Status = %v; want 204", resp.StatusCode)
+		}
+		return nil
+	}
+	ct.server = func() error {
+		ct.greet()
+		var buf bytes.Buffer
+		enc := hpack.NewEncoder(&buf)
+		nreq := 0
+
+		for {
+			f, err := ct.fr.ReadFrame()
+			if err != nil {
+				select {
+				case <-clientDone:
+					// If the client's done, it
+					// will have reported any
+					// errors on its side.
+					return nil
+				default:
+					return err
+				}
+			}
+			switch f := f.(type) {
+			case *WindowUpdateFrame, *SettingsFrame:
+			case *HeadersFrame:
+				if !f.HeadersEnded() {
+					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
+				}
+				nreq++
+				if nreq == 1 {
+					ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
+				} else {
+					enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
+					ct.fr.WriteHeaders(HeadersFrameParam{
+						StreamID:      f.StreamID,
+						EndHeaders:    true,
+						EndStream:     true,
+						BlockFragment: buf.Bytes(),
+					})
+				}
+			default:
+				return fmt.Errorf("Unexpected client frame %v", f)
+			}
+		}
+	}
+	ct.run()
+}
+
+func TestTransportRetryHasLimit(t *testing.T) {
+	// Skip in short mode because the total expected delay is 1s+2s+4s+8s+16s=29s.
+	if testing.Short() {
+		t.Skip("skipping long test in short mode")
+	}
+	clientDone := make(chan struct{})
+	ct := newClientTester(t)
+	ct.client = func() error {
+		defer ct.cc.(*net.TCPConn).CloseWrite()
+		defer close(clientDone)
+		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+		resp, err := ct.tr.RoundTrip(req)
+		if err == nil {
+			return fmt.Errorf("RoundTrip expected error, got response: %+v", resp)
+		}
+		t.Logf("expected error, got: %v", err)
+		return nil
+	}
+	ct.server = func() error {
+		ct.greet()
+		for {
+			f, err := ct.fr.ReadFrame()
+			if err != nil {
+				select {
+				case <-clientDone:
+					// If the client's done, it
+					// will have reported any
+					// errors on its side.
+					return nil
+				default:
+					return err
+				}
+			}
+			switch f := f.(type) {
+			case *WindowUpdateFrame, *SettingsFrame:
+			case *HeadersFrame:
+				if !f.HeadersEnded() {
+					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
+				}
+				ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream)
+			default:
+				return fmt.Errorf("Unexpected client frame %v", f)
+			}
+		}
+	}
+	ct.run()
+}
+
 func TestAuthorityAddr(t *testing.T) {
 	tests := []struct {
 		scheme, authority string