http2: don't return from RoundTrip until request body is closed
Moving the Request.Body.Close call out from the ClientConn mutex
results in some cases where RoundTrip returns while the Close is
still in progress. This should be legal (RoundTrip explicitly allows
for this), but net/http relies on Close never being called after
RoundTrip returns.
Add additional synchronization to ensure Close calls complete
before RoundTrip returns.
Fixes golang/go#55896
Change-Id: Ie3d4773966745e83987d219927929cb56ec1a7ad
Reviewed-on: https://go-review.googlesource.com/c/net/+/435535
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 9a874f7..52991f3 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -345,8 +345,8 @@
readErr error // sticky read error; owned by transportResponseBody.Read
reqBody io.ReadCloser
- reqBodyContentLength int64 // -1 means unknown
- reqBodyClosed bool // body has been closed; guarded by cc.mu
+ reqBodyContentLength int64 // -1 means unknown
+ reqBodyClosed chan struct{} // guarded by cc.mu; non-nil on Close, closed when done
// owned by writeRequest:
sentEndStream bool // sent an END_STREAM flag to the peer
@@ -376,46 +376,48 @@
}
func (cs *clientStream) abortStream(err error) {
- var reqBody io.ReadCloser
- defer func() {
- if reqBody != nil {
- reqBody.Close()
- }
- }()
cs.cc.mu.Lock()
defer cs.cc.mu.Unlock()
- reqBody = cs.abortStreamLocked(err)
+ cs.abortStreamLocked(err)
}
-func (cs *clientStream) abortStreamLocked(err error) io.ReadCloser {
+func (cs *clientStream) abortStreamLocked(err error) {
cs.abortOnce.Do(func() {
cs.abortErr = err
close(cs.abort)
})
- var reqBody io.ReadCloser
- if cs.reqBody != nil && !cs.reqBodyClosed {
- cs.reqBodyClosed = true
- reqBody = cs.reqBody
+ if cs.reqBody != nil {
+ cs.closeReqBodyLocked()
}
// TODO(dneil): Clean up tests where cs.cc.cond is nil.
if cs.cc.cond != nil {
// Wake up writeRequestBody if it is waiting on flow control.
cs.cc.cond.Broadcast()
}
- return reqBody
}
func (cs *clientStream) abortRequestBodyWrite() {
cc := cs.cc
cc.mu.Lock()
defer cc.mu.Unlock()
- if cs.reqBody != nil && !cs.reqBodyClosed {
- cs.reqBody.Close()
- cs.reqBodyClosed = true
+ if cs.reqBody != nil && cs.reqBodyClosed == nil {
+ cs.closeReqBodyLocked()
cc.cond.Broadcast()
}
}
+func (cs *clientStream) closeReqBodyLocked() {
+ if cs.reqBodyClosed != nil {
+ return
+ }
+ cs.reqBodyClosed = make(chan struct{})
+ reqBodyClosed := cs.reqBodyClosed
+ go func() {
+ cs.reqBody.Close()
+ close(reqBodyClosed)
+ }()
+}
+
type stickyErrWriter struct {
conn net.Conn
timeout time.Duration
@@ -771,12 +773,6 @@
}
func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
- var reqBodiesToClose []io.ReadCloser
- defer func() {
- for _, reqBody := range reqBodiesToClose {
- reqBody.Close()
- }
- }()
cc.mu.Lock()
defer cc.mu.Unlock()
@@ -793,10 +789,7 @@
last := f.LastStreamID
for streamID, cs := range cc.streams {
if streamID > last {
- reqBody := cs.abortStreamLocked(errClientConnGotGoAway)
- if reqBody != nil {
- reqBodiesToClose = append(reqBodiesToClose, reqBody)
- }
+ cs.abortStreamLocked(errClientConnGotGoAway)
}
}
}
@@ -1049,19 +1042,11 @@
func (cc *ClientConn) closeForError(err error) {
cc.mu.Lock()
cc.closed = true
-
- var reqBodiesToClose []io.ReadCloser
for _, cs := range cc.streams {
- reqBody := cs.abortStreamLocked(err)
- if reqBody != nil {
- reqBodiesToClose = append(reqBodiesToClose, reqBody)
- }
+ cs.abortStreamLocked(err)
}
cc.cond.Broadcast()
cc.mu.Unlock()
- for _, reqBody := range reqBodiesToClose {
- reqBody.Close()
- }
cc.closeConn()
}
@@ -1458,11 +1443,19 @@
// and in multiple cases: server replies <=299 and >299
// while still writing request body
cc.mu.Lock()
+ mustCloseBody := false
+ if cs.reqBody != nil && cs.reqBodyClosed == nil {
+ mustCloseBody = true
+ cs.reqBodyClosed = make(chan struct{})
+ }
bodyClosed := cs.reqBodyClosed
- cs.reqBodyClosed = true
cc.mu.Unlock()
- if !bodyClosed && cs.reqBody != nil {
+ if mustCloseBody {
cs.reqBody.Close()
+ close(bodyClosed)
+ }
+ if bodyClosed != nil {
+ <-bodyClosed
}
if err != nil && cs.sentEndStream {
@@ -1642,7 +1635,7 @@
}
if err != nil {
cc.mu.Lock()
- bodyClosed := cs.reqBodyClosed
+ bodyClosed := cs.reqBodyClosed != nil
cc.mu.Unlock()
switch {
case bodyClosed:
@@ -1737,7 +1730,7 @@
if cc.closed {
return 0, errClientConnClosed
}
- if cs.reqBodyClosed {
+ if cs.reqBodyClosed != nil {
return 0, errStopReqBodyWrite
}
select {
@@ -2110,24 +2103,17 @@
}
cc.closed = true
- var reqBodiesToClose []io.ReadCloser
for _, cs := range cc.streams {
select {
case <-cs.peerClosed:
// The server closed the stream before closing the conn,
// so no need to interrupt it.
default:
- reqBody := cs.abortStreamLocked(err)
- if reqBody != nil {
- reqBodiesToClose = append(reqBodiesToClose, reqBody)
- }
+ cs.abortStreamLocked(err)
}
}
cc.cond.Broadcast()
cc.mu.Unlock()
- for _, reqBody := range reqBodiesToClose {
- reqBody.Close()
- }
}
// countReadFrameError calls Transport.CountError with a string