http2: close the Request's Body when aborting a stream
After RoundTrip returns, closing the Response's Body should
interrupt any ongoing write of the request body. Close the
Request's Body to unblock the body writer.
Take additional care around the use of a Request after
its Response's Body has been closed. The RoundTripper contract
permits the caller to modify the request after the Response's
body has been closed.
Updates golang/go#48908.
Change-Id: I261e08eb5d70016b49942d72833f46b2ae83962a
Reviewed-on: https://go-review.googlesource.com/c/net/+/355491
Trust: Damien Neil <dneil@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index a5ba742..2ff6544 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -300,12 +300,17 @@
// clientStream is the state for a single HTTP/2 stream. One of these
// is created for each Transport.RoundTrip call.
type clientStream struct {
- cc *ClientConn
- req *http.Request
+ cc *ClientConn
+
+ // Fields of Request that we may access even after the response body is closed.
+ ctx context.Context
+ reqCancel <-chan struct{}
+
trace *httptrace.ClientTrace // or nil
ID uint32
bufPipe pipe // buffered pipe with the flow-controlled response payload
requestedGzip bool
+ isHead bool
abortOnce sync.Once
abort chan struct{} // closed to signal stream should end immediately
@@ -322,7 +327,10 @@
inflow flow // guarded by cc.mu
bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read
readErr error // sticky read error; owned by transportResponseBody.Read
- stopReqBody error // if non-nil, stop writing req body; guarded by cc.mu
+
+ reqBody io.ReadCloser
+ reqBodyContentLength int64 // -1 means unknown
+ reqBodyClosed bool // body has been closed; guarded by cc.mu
// owned by writeRequest:
sentEndStream bool // sent an END_STREAM flag to the peer
@@ -362,6 +370,10 @@
cs.abortErr = err
close(cs.abort)
})
+ if cs.reqBody != nil && !cs.reqBodyClosed {
+ cs.reqBody.Close()
+ cs.reqBodyClosed = true
+ }
// TODO(dneil): Clean up tests where cs.cc.cond is nil.
if cs.cc.cond != nil {
// Wake up writeRequestBody if it is waiting on flow control.
@@ -369,17 +381,15 @@
}
}
-func (cs *clientStream) abortRequestBodyWrite(err error) {
- if err == nil {
- panic("nil error")
- }
+func (cs *clientStream) abortRequestBodyWrite() {
cc := cs.cc
cc.mu.Lock()
- if cs.stopReqBody == nil {
- cs.stopReqBody = err
+ defer cc.mu.Unlock()
+ if cs.reqBody != nil && !cs.reqBodyClosed {
+ cs.reqBody.Close()
+ cs.reqBodyClosed = true
cc.cond.Broadcast()
}
- cc.mu.Unlock()
}
type stickyErrWriter struct {
@@ -1065,15 +1075,19 @@
func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
ctx := req.Context()
cs := &clientStream{
- cc: cc,
- req: req,
- trace: httptrace.ContextClientTrace(req.Context()),
- peerClosed: make(chan struct{}),
- abort: make(chan struct{}),
- respHeaderRecv: make(chan struct{}),
- donec: make(chan struct{}),
+ cc: cc,
+ ctx: ctx,
+ reqCancel: req.Cancel,
+ isHead: req.Method == "HEAD",
+ reqBody: req.Body,
+ reqBodyContentLength: actualContentLength(req),
+ trace: httptrace.ContextClientTrace(ctx),
+ peerClosed: make(chan struct{}),
+ abort: make(chan struct{}),
+ respHeaderRecv: make(chan struct{}),
+ donec: make(chan struct{}),
}
- go cs.doRequest()
+ go cs.doRequest(req)
waitDone := func() error {
select {
@@ -1081,7 +1095,7 @@
return nil
case <-ctx.Done():
return ctx.Err()
- case <-req.Cancel:
+ case <-cs.reqCancel:
return errRequestCanceled
}
}
@@ -1100,7 +1114,7 @@
// doesn't, they'll RST_STREAM us soon enough. This is a
// heuristic to avoid adding knobs to Transport. Hopefully
// we can keep it.
- cs.abortRequestBodyWrite(errStopReqBodyWrite)
+ cs.abortRequestBodyWrite()
}
res.Request = req
res.TLS = cc.tlsState
@@ -1117,8 +1131,11 @@
waitDone()
return nil, cs.abortErr
case <-ctx.Done():
- return nil, ctx.Err()
- case <-req.Cancel:
+ err := ctx.Err()
+ cs.abortStream(err)
+ return nil, err
+ case <-cs.reqCancel:
+ cs.abortStream(errRequestCanceled)
return nil, errRequestCanceled
}
}
@@ -1127,8 +1144,8 @@
// doRequest runs for the duration of the request lifetime.
//
// It sends the request and performs post-request cleanup (closing Request.Body, etc.).
-func (cs *clientStream) doRequest() {
- err := cs.writeRequest()
+func (cs *clientStream) doRequest(req *http.Request) {
+ err := cs.writeRequest(req)
cs.cleanupWriteRequest(err)
}
@@ -1139,12 +1156,11 @@
//
// It returns non-nil if the request ends otherwise.
// If the returned error is StreamError, the error Code may be used in resetting the stream.
-func (cs *clientStream) writeRequest() (err error) {
+func (cs *clientStream) writeRequest(req *http.Request) (err error) {
cc := cs.cc
- req := cs.req
- ctx := req.Context()
+ ctx := cs.ctx
- if err := checkConnHeaders(cs.req); err != nil {
+ if err := checkConnHeaders(req); err != nil {
return err
}
@@ -1156,7 +1172,7 @@
}
select {
case cc.reqHeaderMu <- struct{}{}:
- case <-req.Cancel:
+ case <-cs.reqCancel:
return errRequestCanceled
case <-ctx.Done():
return ctx.Err()
@@ -1179,7 +1195,7 @@
if !cc.t.disableCompression() &&
req.Header.Get("Accept-Encoding") == "" &&
req.Header.Get("Range") == "" &&
- req.Method != "HEAD" {
+ !cs.isHead {
// Request gzip only, not deflate. Deflate is ambiguous and
// not as universally supported anyway.
// See: https://zlib.net/zlib_faq.html#faq39
@@ -1198,19 +1214,23 @@
continueTimeout := cc.t.expectContinueTimeout()
if continueTimeout != 0 &&
!httpguts.HeaderValuesContainsToken(
- cs.req.Header["Expect"],
+ req.Header["Expect"],
"100-continue") {
continueTimeout = 0
cs.on100 = make(chan struct{}, 1)
}
- err = cs.encodeAndWriteHeaders()
+ // Past this point (where we send request headers), it is possible for
+ // RoundTrip to return successfully. Since the RoundTrip contract permits
+ // the caller to "mutate or reuse" the Request after closing the Response's Body,
+ // we must take care when referencing the Request from here on.
+ err = cs.encodeAndWriteHeaders(req)
<-cc.reqHeaderMu
if err != nil {
return err
}
- hasBody := actualContentLength(cs.req) != 0
+ hasBody := cs.reqBodyContentLength != 0
if !hasBody {
cs.sentEndStream = true
} else {
@@ -1226,7 +1246,7 @@
err = cs.abortErr
case <-ctx.Done():
err = ctx.Err()
- case <-req.Cancel:
+ case <-cs.reqCancel:
err = errRequestCanceled
}
timer.Stop()
@@ -1236,7 +1256,7 @@
}
}
- if err = cs.writeRequestBody(req.Body); err != nil {
+ if err = cs.writeRequestBody(req); err != nil {
if err != errStopReqBodyWrite {
traceWroteRequest(cs.trace, err)
return err
@@ -1271,16 +1291,15 @@
return cs.abortErr
case <-ctx.Done():
return ctx.Err()
- case <-req.Cancel:
+ case <-cs.reqCancel:
return errRequestCanceled
}
}
}
-func (cs *clientStream) encodeAndWriteHeaders() error {
+func (cs *clientStream) encodeAndWriteHeaders(req *http.Request) error {
cc := cs.cc
- req := cs.req
- ctx := req.Context()
+ ctx := cs.ctx
cc.wmu.Lock()
defer cc.wmu.Unlock()
@@ -1291,7 +1310,7 @@
return cs.abortErr
case <-ctx.Done():
return ctx.Err()
- case <-req.Cancel:
+ case <-cs.reqCancel:
return errRequestCanceled
default:
}
@@ -1301,14 +1320,14 @@
// we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is
// sent by writeRequestBody below, along with any Trailers,
// again in form HEADERS{1}, CONTINUATION{0,})
- trailers, err := commaSeparatedTrailers(cs.req)
+ trailers, err := commaSeparatedTrailers(req)
if err != nil {
return err
}
hasTrailers := trailers != ""
- contentLen := actualContentLength(cs.req)
+ contentLen := actualContentLength(req)
hasBody := contentLen != 0
- hdrs, err := cc.encodeHeaders(cs.req, cs.requestedGzip, trailers, contentLen)
+ hdrs, err := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen)
if err != nil {
return err
}
@@ -1327,7 +1346,6 @@
// cleanupWriteRequest will send a reset to the peer.
func (cs *clientStream) cleanupWriteRequest(err error) {
cc := cs.cc
- req := cs.req
if cs.ID == 0 {
// We were canceled before creating the stream, so return our reservation.
@@ -1338,10 +1356,12 @@
// Request.Body is closed by the Transport,
// and in multiple cases: server replies <=299 and >299
// while still writing request body
- if req.Body != nil {
- if e := req.Body.Close(); err == nil {
- err = e
- }
+ cc.mu.Lock()
+ bodyClosed := cs.reqBodyClosed
+ cs.reqBodyClosed = true
+ cc.mu.Unlock()
+ if !bodyClosed && cs.reqBody != nil {
+ cs.reqBody.Close()
}
if err != nil && cs.sentEndStream {
@@ -1456,7 +1476,7 @@
if n > max {
n = max
}
- if cl := actualContentLength(cs.req); cl != -1 && cl+1 < n {
+ if cl := cs.reqBodyContentLength; cl != -1 && cl+1 < n {
// Add an extra byte past the declared content-length to
// give the caller's Request.Body io.Reader a chance to
// give us more bytes than they declared, so we can catch it
@@ -1471,13 +1491,13 @@
var bufPool sync.Pool // of *[]byte
-func (cs *clientStream) writeRequestBody(body io.Reader) (err error) {
+func (cs *clientStream) writeRequestBody(req *http.Request) (err error) {
cc := cs.cc
+ body := cs.reqBody
sentEnd := false // whether we sent the final DATA frame w/ END_STREAM
- req := cs.req
hasTrailers := req.Trailer != nil
- remainLen := actualContentLength(req)
+ remainLen := cs.reqBodyContentLength
hasContentLen := remainLen != -1
cc.mu.Lock()
@@ -1529,12 +1549,7 @@
for len(remain) > 0 && err == nil {
var allowed int32
allowed, err = cs.awaitFlowControl(len(remain))
- switch {
- case err == errStopReqBodyWrite:
- return err
- case err == errStopReqBodyWriteAndCancel:
- return err
- case err != nil:
+ if err != nil {
return err
}
cc.wmu.Lock()
@@ -1565,16 +1580,26 @@
return nil
}
+ // Since the RoundTrip contract permits the caller to "mutate or reuse"
+ // a request after the Response's Body is closed, verify that this hasn't
+ // happened before accessing the trailers.
+ cc.mu.Lock()
+ trailer := req.Trailer
+ err = cs.abortErr
+ cc.mu.Unlock()
+ if err != nil {
+ return err
+ }
+
cc.wmu.Lock()
+ defer cc.wmu.Unlock()
var trls []byte
- if hasTrailers {
- trls, err = cc.encodeTrailers(req)
+ if len(trailer) > 0 {
+ trls, err = cc.encodeTrailers(trailer)
if err != nil {
- cc.wmu.Unlock()
return err
}
}
- defer cc.wmu.Unlock()
// Two ways to send END_STREAM: either with trailers, or
// with an empty DATA frame.
@@ -1595,23 +1620,22 @@
// if the stream is dead.
func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) {
cc := cs.cc
- req := cs.req
- ctx := req.Context()
+ ctx := cs.ctx
cc.mu.Lock()
defer cc.mu.Unlock()
for {
if cc.closed {
return 0, errClientConnClosed
}
- if cs.stopReqBody != nil {
- return 0, cs.stopReqBody
+ if cs.reqBodyClosed {
+ return 0, errStopReqBodyWrite
}
select {
case <-cs.abort:
return 0, cs.abortErr
case <-ctx.Done():
return 0, ctx.Err()
- case <-req.Cancel:
+ case <-cs.reqCancel:
return 0, errRequestCanceled
default:
}
@@ -1825,11 +1849,11 @@
}
// requires cc.wmu be held.
-func (cc *ClientConn) encodeTrailers(req *http.Request) ([]byte, error) {
+func (cc *ClientConn) encodeTrailers(trailer http.Header) ([]byte, error) {
cc.hbuf.Reset()
hlSize := uint64(0)
- for k, vv := range req.Trailer {
+ for k, vv := range trailer {
for _, v := range vv {
hf := hpack.HeaderField{Name: k, Value: v}
hlSize += uint64(hf.Size())
@@ -1839,7 +1863,7 @@
return nil, errRequestHeaderListSize
}
- for k, vv := range req.Trailer {
+ for k, vv := range trailer {
lowKey, ascii := asciiToLower(k)
if !ascii {
// Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
@@ -2230,7 +2254,7 @@
// more safe smuggling-wise to ignore.
}
- if cs.req.Method == "HEAD" {
+ if cs.isHead {
res.Body = noBody
return res, nil
}
@@ -2287,8 +2311,7 @@
}
// transportResponseBody is the concrete type of Transport.RoundTrip's
-// Response.Body. It is an io.ReadCloser. On Read, it reads from cs.body.
-// On Close it sends RST_STREAM if EOF wasn't already seen.
+// Response.Body. It is an io.ReadCloser.
type transportResponseBody struct {
cs *clientStream
}
@@ -2371,6 +2394,8 @@
}
cc.mu.Unlock()
+ // TODO(dneil): Acquiring this mutex can block indefinitely.
+ // Move flow control return to a goroutine?
cc.wmu.Lock()
// Return connection-level flow control.
if unread > 0 {
@@ -2385,9 +2410,9 @@
select {
case <-cs.donec:
- case <-cs.req.Context().Done():
- return cs.req.Context().Err()
- case <-cs.req.Cancel:
+ case <-cs.ctx.Done():
+ return cs.ctx.Err()
+ case <-cs.reqCancel:
return errRequestCanceled
}
return nil
@@ -2441,7 +2466,7 @@
return nil
}
if f.Length > 0 {
- if cs.req.Method == "HEAD" && len(data) > 0 {
+ if cs.isHead && len(data) > 0 {
cc.logf("protocol error: received DATA on a HEAD request")
rl.endStreamError(cs, StreamError{
StreamID: f.StreamID,
diff --git a/http2/transport_test.go b/http2/transport_test.go
index b250738..322a4c4 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -5676,3 +5676,28 @@
})
}
}
+
+func TestTransportCloseResponseBodyWhileRequestBodyHangs(t *testing.T) {
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(200)
+ w.(http.Flusher).Flush()
+ io.Copy(io.Discard, r.Body)
+ }, optOnlyServer)
+ defer st.Close()
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure}
+ defer tr.CloseIdleConnections()
+
+ pr, pw := net.Pipe()
+ req, err := http.NewRequest("GET", st.ts.URL, pr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Closing the Response's Body interrupts the blocked body read.
+ res.Body.Close()
+ pw.Close()
+}