http2: close the Request's Body when aborting a stream

After RoundTrip returns, closing the Response's Body should
interrupt any ongoing write of the request body. Close the
Request's Body to unblock the body writer.

Take additional care around the use of a Request after
its Response's Body has been closed. The RoundTripper contract
permits the caller to modify the request after the Response's
body has been closed.

Updates golang/go#48908.

Change-Id: I261e08eb5d70016b49942d72833f46b2ae83962a
Reviewed-on: https://go-review.googlesource.com/c/net/+/355491
Trust: Damien Neil <dneil@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index a5ba742..2ff6544 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -300,12 +300,17 @@
 // clientStream is the state for a single HTTP/2 stream. One of these
 // is created for each Transport.RoundTrip call.
 type clientStream struct {
-	cc            *ClientConn
-	req           *http.Request
+	cc *ClientConn
+
+	// Fields of Request that we may access even after the response body is closed.
+	ctx       context.Context
+	reqCancel <-chan struct{}
+
 	trace         *httptrace.ClientTrace // or nil
 	ID            uint32
 	bufPipe       pipe // buffered pipe with the flow-controlled response payload
 	requestedGzip bool
+	isHead        bool
 
 	abortOnce sync.Once
 	abort     chan struct{} // closed to signal stream should end immediately
@@ -322,7 +327,10 @@
 	inflow      flow  // guarded by cc.mu
 	bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read
 	readErr     error // sticky read error; owned by transportResponseBody.Read
-	stopReqBody error // if non-nil, stop writing req body; guarded by cc.mu
+
+	reqBody              io.ReadCloser
+	reqBodyContentLength int64 // -1 means unknown
+	reqBodyClosed        bool  // body has been closed; guarded by cc.mu
 
 	// owned by writeRequest:
 	sentEndStream bool // sent an END_STREAM flag to the peer
@@ -362,6 +370,10 @@
 		cs.abortErr = err
 		close(cs.abort)
 	})
+	if cs.reqBody != nil && !cs.reqBodyClosed {
+		cs.reqBody.Close()
+		cs.reqBodyClosed = true
+	}
 	// TODO(dneil): Clean up tests where cs.cc.cond is nil.
 	if cs.cc.cond != nil {
 		// Wake up writeRequestBody if it is waiting on flow control.
@@ -369,17 +381,15 @@
 	}
 }
 
-func (cs *clientStream) abortRequestBodyWrite(err error) {
-	if err == nil {
-		panic("nil error")
-	}
+func (cs *clientStream) abortRequestBodyWrite() {
 	cc := cs.cc
 	cc.mu.Lock()
-	if cs.stopReqBody == nil {
-		cs.stopReqBody = err
+	defer cc.mu.Unlock()
+	if cs.reqBody != nil && !cs.reqBodyClosed {
+		cs.reqBody.Close()
+		cs.reqBodyClosed = true
 		cc.cond.Broadcast()
 	}
-	cc.mu.Unlock()
 }
 
 type stickyErrWriter struct {
@@ -1065,15 +1075,19 @@
 func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	ctx := req.Context()
 	cs := &clientStream{
-		cc:             cc,
-		req:            req,
-		trace:          httptrace.ContextClientTrace(req.Context()),
-		peerClosed:     make(chan struct{}),
-		abort:          make(chan struct{}),
-		respHeaderRecv: make(chan struct{}),
-		donec:          make(chan struct{}),
+		cc:                   cc,
+		ctx:                  ctx,
+		reqCancel:            req.Cancel,
+		isHead:               req.Method == "HEAD",
+		reqBody:              req.Body,
+		reqBodyContentLength: actualContentLength(req),
+		trace:                httptrace.ContextClientTrace(ctx),
+		peerClosed:           make(chan struct{}),
+		abort:                make(chan struct{}),
+		respHeaderRecv:       make(chan struct{}),
+		donec:                make(chan struct{}),
 	}
-	go cs.doRequest()
+	go cs.doRequest(req)
 
 	waitDone := func() error {
 		select {
@@ -1081,7 +1095,7 @@
 			return nil
 		case <-ctx.Done():
 			return ctx.Err()
-		case <-req.Cancel:
+		case <-cs.reqCancel:
 			return errRequestCanceled
 		}
 	}
@@ -1100,7 +1114,7 @@
 				// doesn't, they'll RST_STREAM us soon enough. This is a
 				// heuristic to avoid adding knobs to Transport. Hopefully
 				// we can keep it.
-				cs.abortRequestBodyWrite(errStopReqBodyWrite)
+				cs.abortRequestBodyWrite()
 			}
 			res.Request = req
 			res.TLS = cc.tlsState
@@ -1117,8 +1131,11 @@
 			waitDone()
 			return nil, cs.abortErr
 		case <-ctx.Done():
-			return nil, ctx.Err()
-		case <-req.Cancel:
+			err := ctx.Err()
+			cs.abortStream(err)
+			return nil, err
+		case <-cs.reqCancel:
+			cs.abortStream(errRequestCanceled)
 			return nil, errRequestCanceled
 		}
 	}
@@ -1127,8 +1144,8 @@
 // doRequest runs for the duration of the request lifetime.
 //
 // It sends the request and performs post-request cleanup (closing Request.Body, etc.).
-func (cs *clientStream) doRequest() {
-	err := cs.writeRequest()
+func (cs *clientStream) doRequest(req *http.Request) {
+	err := cs.writeRequest(req)
 	cs.cleanupWriteRequest(err)
 }
 
@@ -1139,12 +1156,11 @@
 //
 // It returns non-nil if the request ends otherwise.
 // If the returned error is StreamError, the error Code may be used in resetting the stream.
-func (cs *clientStream) writeRequest() (err error) {
+func (cs *clientStream) writeRequest(req *http.Request) (err error) {
 	cc := cs.cc
-	req := cs.req
-	ctx := req.Context()
+	ctx := cs.ctx
 
-	if err := checkConnHeaders(cs.req); err != nil {
+	if err := checkConnHeaders(req); err != nil {
 		return err
 	}
 
@@ -1156,7 +1172,7 @@
 	}
 	select {
 	case cc.reqHeaderMu <- struct{}{}:
-	case <-req.Cancel:
+	case <-cs.reqCancel:
 		return errRequestCanceled
 	case <-ctx.Done():
 		return ctx.Err()
@@ -1179,7 +1195,7 @@
 	if !cc.t.disableCompression() &&
 		req.Header.Get("Accept-Encoding") == "" &&
 		req.Header.Get("Range") == "" &&
-		req.Method != "HEAD" {
+		!cs.isHead {
 		// Request gzip only, not deflate. Deflate is ambiguous and
 		// not as universally supported anyway.
 		// See: https://zlib.net/zlib_faq.html#faq39
@@ -1198,19 +1214,23 @@
 	continueTimeout := cc.t.expectContinueTimeout()
 	if continueTimeout != 0 &&
 		!httpguts.HeaderValuesContainsToken(
-			cs.req.Header["Expect"],
+			req.Header["Expect"],
 			"100-continue") {
 		continueTimeout = 0
 		cs.on100 = make(chan struct{}, 1)
 	}
 
-	err = cs.encodeAndWriteHeaders()
+	// Past this point (where we send request headers), it is possible for
+	// RoundTrip to return successfully. Since the RoundTrip contract permits
+	// the caller to "mutate or reuse" the Request after closing the Response's Body,
+	// we must take care when referencing the Request from here on.
+	err = cs.encodeAndWriteHeaders(req)
 	<-cc.reqHeaderMu
 	if err != nil {
 		return err
 	}
 
-	hasBody := actualContentLength(cs.req) != 0
+	hasBody := cs.reqBodyContentLength != 0
 	if !hasBody {
 		cs.sentEndStream = true
 	} else {
@@ -1226,7 +1246,7 @@
 				err = cs.abortErr
 			case <-ctx.Done():
 				err = ctx.Err()
-			case <-req.Cancel:
+			case <-cs.reqCancel:
 				err = errRequestCanceled
 			}
 			timer.Stop()
@@ -1236,7 +1256,7 @@
 			}
 		}
 
-		if err = cs.writeRequestBody(req.Body); err != nil {
+		if err = cs.writeRequestBody(req); err != nil {
 			if err != errStopReqBodyWrite {
 				traceWroteRequest(cs.trace, err)
 				return err
@@ -1271,16 +1291,15 @@
 			return cs.abortErr
 		case <-ctx.Done():
 			return ctx.Err()
-		case <-req.Cancel:
+		case <-cs.reqCancel:
 			return errRequestCanceled
 		}
 	}
 }
 
-func (cs *clientStream) encodeAndWriteHeaders() error {
+func (cs *clientStream) encodeAndWriteHeaders(req *http.Request) error {
 	cc := cs.cc
-	req := cs.req
-	ctx := req.Context()
+	ctx := cs.ctx
 
 	cc.wmu.Lock()
 	defer cc.wmu.Unlock()
@@ -1291,7 +1310,7 @@
 		return cs.abortErr
 	case <-ctx.Done():
 		return ctx.Err()
-	case <-req.Cancel:
+	case <-cs.reqCancel:
 		return errRequestCanceled
 	default:
 	}
@@ -1301,14 +1320,14 @@
 	// we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is
 	// sent by writeRequestBody below, along with any Trailers,
 	// again in form HEADERS{1}, CONTINUATION{0,})
-	trailers, err := commaSeparatedTrailers(cs.req)
+	trailers, err := commaSeparatedTrailers(req)
 	if err != nil {
 		return err
 	}
 	hasTrailers := trailers != ""
-	contentLen := actualContentLength(cs.req)
+	contentLen := actualContentLength(req)
 	hasBody := contentLen != 0
-	hdrs, err := cc.encodeHeaders(cs.req, cs.requestedGzip, trailers, contentLen)
+	hdrs, err := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen)
 	if err != nil {
 		return err
 	}
@@ -1327,7 +1346,6 @@
 // cleanupWriteRequest will send a reset to the peer.
 func (cs *clientStream) cleanupWriteRequest(err error) {
 	cc := cs.cc
-	req := cs.req
 
 	if cs.ID == 0 {
 		// We were canceled before creating the stream, so return our reservation.
@@ -1338,10 +1356,12 @@
 	// Request.Body is closed by the Transport,
 	// and in multiple cases: server replies <=299 and >299
 	// while still writing request body
-	if req.Body != nil {
-		if e := req.Body.Close(); err == nil {
-			err = e
-		}
+	cc.mu.Lock()
+	bodyClosed := cs.reqBodyClosed
+	cs.reqBodyClosed = true
+	cc.mu.Unlock()
+	if !bodyClosed && cs.reqBody != nil {
+		cs.reqBody.Close()
 	}
 
 	if err != nil && cs.sentEndStream {
@@ -1456,7 +1476,7 @@
 	if n > max {
 		n = max
 	}
-	if cl := actualContentLength(cs.req); cl != -1 && cl+1 < n {
+	if cl := cs.reqBodyContentLength; cl != -1 && cl+1 < n {
 		// Add an extra byte past the declared content-length to
 		// give the caller's Request.Body io.Reader a chance to
 		// give us more bytes than they declared, so we can catch it
@@ -1471,13 +1491,13 @@
 
 var bufPool sync.Pool // of *[]byte
 
-func (cs *clientStream) writeRequestBody(body io.Reader) (err error) {
+func (cs *clientStream) writeRequestBody(req *http.Request) (err error) {
 	cc := cs.cc
+	body := cs.reqBody
 	sentEnd := false // whether we sent the final DATA frame w/ END_STREAM
 
-	req := cs.req
 	hasTrailers := req.Trailer != nil
-	remainLen := actualContentLength(req)
+	remainLen := cs.reqBodyContentLength
 	hasContentLen := remainLen != -1
 
 	cc.mu.Lock()
@@ -1529,12 +1549,7 @@
 		for len(remain) > 0 && err == nil {
 			var allowed int32
 			allowed, err = cs.awaitFlowControl(len(remain))
-			switch {
-			case err == errStopReqBodyWrite:
-				return err
-			case err == errStopReqBodyWriteAndCancel:
-				return err
-			case err != nil:
+			if err != nil {
 				return err
 			}
 			cc.wmu.Lock()
@@ -1565,16 +1580,26 @@
 		return nil
 	}
 
+	// Since the RoundTrip contract permits the caller to "mutate or reuse"
+	// a request after the Response's Body is closed, verify that this hasn't
+	// happened before accessing the trailers.
+	cc.mu.Lock()
+	trailer := req.Trailer
+	err = cs.abortErr
+	cc.mu.Unlock()
+	if err != nil {
+		return err
+	}
+
 	cc.wmu.Lock()
+	defer cc.wmu.Unlock()
 	var trls []byte
-	if hasTrailers {
-		trls, err = cc.encodeTrailers(req)
+	if len(trailer) > 0 {
+		trls, err = cc.encodeTrailers(trailer)
 		if err != nil {
-			cc.wmu.Unlock()
 			return err
 		}
 	}
-	defer cc.wmu.Unlock()
 
 	// Two ways to send END_STREAM: either with trailers, or
 	// with an empty DATA frame.
@@ -1595,23 +1620,22 @@
 // if the stream is dead.
 func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) {
 	cc := cs.cc
-	req := cs.req
-	ctx := req.Context()
+	ctx := cs.ctx
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
 	for {
 		if cc.closed {
 			return 0, errClientConnClosed
 		}
-		if cs.stopReqBody != nil {
-			return 0, cs.stopReqBody
+		if cs.reqBodyClosed {
+			return 0, errStopReqBodyWrite
 		}
 		select {
 		case <-cs.abort:
 			return 0, cs.abortErr
 		case <-ctx.Done():
 			return 0, ctx.Err()
-		case <-req.Cancel:
+		case <-cs.reqCancel:
 			return 0, errRequestCanceled
 		default:
 		}
@@ -1825,11 +1849,11 @@
 }
 
 // requires cc.wmu be held.
-func (cc *ClientConn) encodeTrailers(req *http.Request) ([]byte, error) {
+func (cc *ClientConn) encodeTrailers(trailer http.Header) ([]byte, error) {
 	cc.hbuf.Reset()
 
 	hlSize := uint64(0)
-	for k, vv := range req.Trailer {
+	for k, vv := range trailer {
 		for _, v := range vv {
 			hf := hpack.HeaderField{Name: k, Value: v}
 			hlSize += uint64(hf.Size())
@@ -1839,7 +1863,7 @@
 		return nil, errRequestHeaderListSize
 	}
 
-	for k, vv := range req.Trailer {
+	for k, vv := range trailer {
 		lowKey, ascii := asciiToLower(k)
 		if !ascii {
 			// Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
@@ -2230,7 +2254,7 @@
 		// more safe smuggling-wise to ignore.
 	}
 
-	if cs.req.Method == "HEAD" {
+	if cs.isHead {
 		res.Body = noBody
 		return res, nil
 	}
@@ -2287,8 +2311,7 @@
 }
 
 // transportResponseBody is the concrete type of Transport.RoundTrip's
-// Response.Body. It is an io.ReadCloser. On Read, it reads from cs.body.
-// On Close it sends RST_STREAM if EOF wasn't already seen.
+// Response.Body. It is an io.ReadCloser.
 type transportResponseBody struct {
 	cs *clientStream
 }
@@ -2371,6 +2394,8 @@
 		}
 		cc.mu.Unlock()
 
+		// TODO(dneil): Acquiring this mutex can block indefinitely.
+		// Move flow control return to a goroutine?
 		cc.wmu.Lock()
 		// Return connection-level flow control.
 		if unread > 0 {
@@ -2385,9 +2410,9 @@
 
 	select {
 	case <-cs.donec:
-	case <-cs.req.Context().Done():
-		return cs.req.Context().Err()
-	case <-cs.req.Cancel:
+	case <-cs.ctx.Done():
+		return cs.ctx.Err()
+	case <-cs.reqCancel:
 		return errRequestCanceled
 	}
 	return nil
@@ -2441,7 +2466,7 @@
 		return nil
 	}
 	if f.Length > 0 {
-		if cs.req.Method == "HEAD" && len(data) > 0 {
+		if cs.isHead && len(data) > 0 {
 			cc.logf("protocol error: received DATA on a HEAD request")
 			rl.endStreamError(cs, StreamError{
 				StreamID: f.StreamID,
diff --git a/http2/transport_test.go b/http2/transport_test.go
index b250738..322a4c4 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -5676,3 +5676,28 @@
 		})
 	}
 }
+
+func TestTransportCloseResponseBodyWhileRequestBodyHangs(t *testing.T) {
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(200)
+		w.(http.Flusher).Flush()
+		io.Copy(io.Discard, r.Body)
+	}, optOnlyServer)
+	defer st.Close()
+
+	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+	defer tr.CloseIdleConnections()
+
+	pr, pw := net.Pipe()
+	req, err := http.NewRequest("GET", st.ts.URL, pr)
+	if err != nil {
+		t.Fatal(err)
+	}
+	res, err := tr.RoundTrip(req)
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Closing the Response's Body interrupts the blocked body read.
+	res.Body.Close()
+	pw.Close()
+}