http2: close request body after early RoundTrip failures

The RoundTrip contract requires that the request Body be closed,
even when an error occurs sending the request.

Fix several cases where the body was not closed by hoisting the
Close call to Transport.RoundTripOpt. Now ClientConn.roundTrip
takes responsibility for closing the body once the body write
begins; otherwise, the caller does so.

Fix the case where a new body is acquired via Request.GetBody
to close the previous body, matching net/http's behavior.

Fixes golang/go#48341.

Change-Id: Id9dc682d4d86a1c255c7c0d864208ff76ed53eb2
Reviewed-on: https://go-review.googlesource.com/c/net/+/349489
Trust: Damien Neil <dneil@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 7d5c876..7c2c013 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -490,6 +490,7 @@
 		}
 		reused := !atomic.CompareAndSwapUint32(&cc.reused, 0, 1)
 		traceGotConn(req, cc, reused)
+		body := req.Body
 		res, gotErrAfterReqBodyWrite, err := cc.roundTrip(req)
 		if err != nil && retry <= 6 {
 			if req, err = shouldRetryRequest(req, err, gotErrAfterReqBodyWrite); err == nil {
@@ -503,12 +504,17 @@
 				case <-time.After(time.Second * time.Duration(backoff)):
 					continue
 				case <-req.Context().Done():
-					return nil, req.Context().Err()
+					err = req.Context().Err()
 				}
 			}
 		}
 		if err != nil {
 			t.vlogf("RoundTrip failure: %v", err)
+			// If the error occurred after the body write started,
+			// the body writer will close the body. Otherwise, do so here.
+			if body != nil && !gotErrAfterReqBodyWrite {
+				body.Close()
+			}
 			return nil, err
 		}
 		return res, nil
@@ -547,7 +553,7 @@
 	// If the request body can be reset back to its original
 	// state via the optional req.GetBody, do that.
 	if req.GetBody != nil {
-		// TODO: consider a req.Body.Close here? or audit that all caller paths do?
+		req.Body.Close()
 		body, err := req.GetBody()
 		if err != nil {
 			return nil, err
@@ -1085,13 +1091,14 @@
 
 	if werr != nil {
 		if hasBody {
-			req.Body.Close() // per RoundTripper contract
 			bodyWriter.cancel()
 		}
 		cc.forgetStreamID(cs.ID)
 		// Don't bother sending a RST_STREAM (our write already failed;
 		// no need to keep writing)
 		traceWroteRequest(cs.trace, werr)
+		// TODO(dneil): An error occurred while writing the headers.
+		// Should we return an error indicating that this request can be retried?
 		return nil, false, werr
 	}
 
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 4412a89..efc2695 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -3905,7 +3905,8 @@
 				if k >= maxConcurrent {
 					<-unblockClient
 				}
-				req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), nil)
+				body := newStaticCloseChecker("")
+				req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), body)
 				if k == maxConcurrent {
 					// This request will be canceled.
 					cancel := make(chan struct{})
@@ -3930,6 +3931,9 @@
 						return
 					}
 				}
+				if err := body.isClosed(); err != nil {
+					errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
+				}
 			}(k)
 		}
 		return nil
@@ -3990,6 +3994,7 @@
 				if nreq == maxConcurrent+1 {
 					close(writeResp)
 				}
+			case *DataFrame:
 			default:
 				return fmt.Errorf("Unexpected client frame %v", f)
 			}
@@ -4905,11 +4910,41 @@
 	closed chan struct{}
 }
 
+func newCloseChecker(r io.ReadCloser) *closeChecker {
+	return &closeChecker{r, make(chan struct{})}
+}
+
+func newStaticCloseChecker(body string) *closeChecker {
+	return newCloseChecker(io.NopCloser(strings.NewReader("body")))
+}
+
+func (rc *closeChecker) Read(b []byte) (n int, err error) {
+	select {
+	default:
+	case <-rc.closed:
+		panic("read from closed body")
+	}
+	return rc.ReadCloser.Read(b)
+}
+
 func (rc *closeChecker) Close() error {
 	close(rc.closed)
 	return rc.ReadCloser.Close()
 }
 
+func (rc *closeChecker) isClosed() error {
+	// The RoundTrip contract says that it will close the request body,
+	// but that it may do so in a separate goroutine. Wait a reasonable
+	// amount of time before concluding that the body isn't being closed.
+	timeout := time.Duration(10 * time.Second)
+	select {
+	case <-rc.closed:
+	case <-time.After(timeout):
+		return fmt.Errorf("body not closed after %v", timeout)
+	}
+	return nil
+}
+
 func TestTransportCloseRequestBody(t *testing.T) {
 	var statusCode int
 	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
@@ -4929,8 +4964,8 @@
 		t.Run(fmt.Sprintf("status=%d", status), func(t *testing.T) {
 			statusCode = status
 			pr, pw := io.Pipe()
-			pipeClosed := make(chan struct{})
-			req, err := http.NewRequest("PUT", "https://dummy.tld/", &closeChecker{pr, pipeClosed})
+			body := newCloseChecker(pr)
+			req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -4940,7 +4975,9 @@
 			}
 			res.Body.Close()
 			pw.Close()
-			<-pipeClosed
+			if err := body.isClosed(); err != nil {
+				t.Fatal(err)
+			}
 		})
 	}
 }