http2: close the request body if needed

As per client.Do and Request.Body, the transport is responsible to close
the request Body.
If there was an error or non 1xx/2xx status code, the transport will
wait for the body writer to complete. If there is no data available to
read, the body writer will block indefinitely. To prevent this, the body
will be closed if it hasn't already.
If there was a 1xx/2xx status code, the body will be closed eventually.

Updates golang/go#43989

Change-Id: I9a4a5f13658122c562baf915e2c0c8992a023278
Reviewed-on: https://go-review.googlesource.com/c/net/+/323689
Reviewed-by: Damien Neil <dneil@google.com>
Trust: Damien Neil <dneil@google.com>
Trust: Alexander Rakoczy <alex@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index b97adff..b261beb 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -385,8 +385,13 @@
 	}
 	cc := cs.cc
 	cc.mu.Lock()
-	cs.stopReqBody = err
-	cc.cond.Broadcast()
+	if cs.stopReqBody == nil {
+		cs.stopReqBody = err
+		if cs.req.Body != nil {
+			cs.req.Body.Close()
+		}
+		cc.cond.Broadcast()
+	}
 	cc.mu.Unlock()
 }
 
@@ -1110,40 +1115,28 @@
 		return res, false, nil
 	}
 
+	handleError := func(err error) (*http.Response, bool, error) {
+		if !hasBody || bodyWritten {
+			cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
+		} else {
+			bodyWriter.cancel()
+			cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
+			<-bodyWriter.resc
+		}
+		cc.forgetStreamID(cs.ID)
+		return nil, cs.getStartedWrite(), err
+	}
+
 	for {
 		select {
 		case re := <-readLoopResCh:
 			return handleReadLoopResponse(re)
 		case <-respHeaderTimer:
-			if !hasBody || bodyWritten {
-				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
-			} else {
-				bodyWriter.cancel()
-				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
-				<-bodyWriter.resc
-			}
-			cc.forgetStreamID(cs.ID)
-			return nil, cs.getStartedWrite(), errTimeout
+			return handleError(errTimeout)
 		case <-ctx.Done():
-			if !hasBody || bodyWritten {
-				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
-			} else {
-				bodyWriter.cancel()
-				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
-				<-bodyWriter.resc
-			}
-			cc.forgetStreamID(cs.ID)
-			return nil, cs.getStartedWrite(), ctx.Err()
+			return handleError(ctx.Err())
 		case <-req.Cancel:
-			if !hasBody || bodyWritten {
-				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
-			} else {
-				bodyWriter.cancel()
-				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
-				<-bodyWriter.resc
-			}
-			cc.forgetStreamID(cs.ID)
-			return nil, cs.getStartedWrite(), errRequestCanceled
+			return handleError(errRequestCanceled)
 		case <-cs.peerReset:
 			// processResetStream already removed the
 			// stream from the streams map; no need for
@@ -1290,7 +1283,13 @@
 		// Request.Body is closed by the Transport,
 		// and in multiple cases: server replies <=299 and >299
 		// while still writing request body
-		cerr := bodyCloser.Close()
+		var cerr error
+		cc.mu.Lock()
+		if cs.stopReqBody == nil {
+			cs.stopReqBody = errStopReqBodyWrite
+			cerr = bodyCloser.Close()
+		}
+		cc.mu.Unlock()
 		if err == nil {
 			err = cerr
 		}
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 750813b..2da7d9d 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -4899,3 +4899,48 @@
 	}
 	res.Body.Close()
 }
+
+type closeChecker struct {
+	io.ReadCloser
+	closed chan struct{}
+}
+
+func (rc *closeChecker) Close() error {
+	close(rc.closed)
+	return rc.ReadCloser.Close()
+}
+
+func TestTransportCloseRequestBody(t *testing.T) {
+	var statusCode int
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(statusCode)
+	}, optOnlyServer)
+	defer st.Close()
+
+	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+	defer tr.CloseIdleConnections()
+	ctx := context.Background()
+	cc, err := tr.dialClientConn(ctx, st.ts.Listener.Addr().String(), false)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	for _, status := range []int{200, 401} {
+		t.Run(fmt.Sprintf("status=%d", status), func(t *testing.T) {
+			statusCode = status
+			pr, pw := io.Pipe()
+			pipeClosed := make(chan struct{})
+			req, err := http.NewRequest("PUT", "https://dummy.tld/", &closeChecker{pr, pipeClosed})
+			if err != nil {
+				t.Fatal(err)
+			}
+			res, err := cc.RoundTrip(req)
+			if err != nil {
+				t.Fatal(err)
+			}
+			res.Body.Close()
+			pw.Close()
+			<-pipeClosed
+		})
+	}
+}