http2: don't return from RoundTrip until request body is closed

Moving the Request.Body.Close call out from under the ClientConn
mutex results in some cases where RoundTrip returns while the Close
is still in progress. This should be legal (the RoundTripper
contract explicitly permits the body to be closed after RoundTrip
returns), but net/http relies on Close never being called after
RoundTrip returns.

Add additional synchronization to ensure Close calls complete
before RoundTrip returns.
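
To illustrate the pattern, here is a minimal, self-contained sketch
(the stream, closeBodyLocked, and finish names are hypothetical
stand-ins for clientStream, closeReqBodyLocked, and the RoundTrip
cleanup path): the first closer allocates a channel under the mutex,
runs the Close on its own goroutine, and closes the channel when
done; the cleanup path then waits on that channel before returning.

	package main

	import (
		"fmt"
		"io"
		"strings"
		"sync"
	)

	type stream struct {
		mu         sync.Mutex
		body       io.ReadCloser
		bodyClosed chan struct{} // guarded by mu; non-nil once Close starts, closed when done
	}

	// closeBodyLocked starts closing the body. The caller must hold s.mu.
	// It is safe to call more than once; only the first call closes the body.
	func (s *stream) closeBodyLocked() {
		if s.bodyClosed != nil {
			return // a Close is already in flight
		}
		s.bodyClosed = make(chan struct{})
		done := s.bodyClosed
		go func() {
			s.body.Close() // may block; runs outside the mutex
			close(done)
		}()
	}

	// finish starts the Close if no one else has, then blocks until
	// it completes, so it never returns while Close is in progress.
	func (s *stream) finish() {
		s.mu.Lock()
		s.closeBodyLocked()
		done := s.bodyClosed
		s.mu.Unlock()
		<-done
	}

	func main() {
		s := &stream{body: io.NopCloser(strings.NewReader("request body"))}
		s.finish()
		fmt.Println("body close completed before return")
	}

In the change below, the <-bodyClosed wait on the cleanup path plays
the role of finish.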

Fixes golang/go#55896

Change-Id: Ie3d4773966745e83987d219927929cb56ec1a7ad
Reviewed-on: https://go-review.googlesource.com/c/net/+/435535
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 9a874f7..52991f3 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -345,8 +345,8 @@
 	readErr     error // sticky read error; owned by transportResponseBody.Read
 
 	reqBody              io.ReadCloser
-	reqBodyContentLength int64 // -1 means unknown
-	reqBodyClosed        bool  // body has been closed; guarded by cc.mu
+	reqBodyContentLength int64         // -1 means unknown
+	reqBodyClosed        chan struct{} // guarded by cc.mu; non-nil on Close, closed when done
 
 	// owned by writeRequest:
 	sentEndStream bool // sent an END_STREAM flag to the peer
@@ -376,46 +376,48 @@
 }
 
 func (cs *clientStream) abortStream(err error) {
-	var reqBody io.ReadCloser
-	defer func() {
-		if reqBody != nil {
-			reqBody.Close()
-		}
-	}()
 	cs.cc.mu.Lock()
 	defer cs.cc.mu.Unlock()
-	reqBody = cs.abortStreamLocked(err)
+	cs.abortStreamLocked(err)
 }
 
-func (cs *clientStream) abortStreamLocked(err error) io.ReadCloser {
+func (cs *clientStream) abortStreamLocked(err error) {
 	cs.abortOnce.Do(func() {
 		cs.abortErr = err
 		close(cs.abort)
 	})
-	var reqBody io.ReadCloser
-	if cs.reqBody != nil && !cs.reqBodyClosed {
-		cs.reqBodyClosed = true
-		reqBody = cs.reqBody
+	if cs.reqBody != nil {
+		cs.closeReqBodyLocked()
 	}
 	// TODO(dneil): Clean up tests where cs.cc.cond is nil.
 	if cs.cc.cond != nil {
 		// Wake up writeRequestBody if it is waiting on flow control.
 		cs.cc.cond.Broadcast()
 	}
-	return reqBody
 }
 
 func (cs *clientStream) abortRequestBodyWrite() {
 	cc := cs.cc
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
-	if cs.reqBody != nil && !cs.reqBodyClosed {
-		cs.reqBody.Close()
-		cs.reqBodyClosed = true
+	if cs.reqBody != nil && cs.reqBodyClosed == nil {
+		cs.closeReqBodyLocked()
 		cc.cond.Broadcast()
 	}
 }
 
+func (cs *clientStream) closeReqBodyLocked() {
+	if cs.reqBodyClosed != nil {
+		return
+	}
+	cs.reqBodyClosed = make(chan struct{})
+	reqBodyClosed := cs.reqBodyClosed
+	go func() {
+		cs.reqBody.Close()
+		close(reqBodyClosed)
+	}()
+}
+
 type stickyErrWriter struct {
 	conn    net.Conn
 	timeout time.Duration
@@ -771,12 +773,6 @@
 }
 
 func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
-	var reqBodiesToClose []io.ReadCloser
-	defer func() {
-		for _, reqBody := range reqBodiesToClose {
-			reqBody.Close()
-		}
-	}()
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
 
@@ -793,10 +789,7 @@
 	last := f.LastStreamID
 	for streamID, cs := range cc.streams {
 		if streamID > last {
-			reqBody := cs.abortStreamLocked(errClientConnGotGoAway)
-			if reqBody != nil {
-				reqBodiesToClose = append(reqBodiesToClose, reqBody)
-			}
+			cs.abortStreamLocked(errClientConnGotGoAway)
 		}
 	}
 }
@@ -1049,19 +1042,11 @@
 func (cc *ClientConn) closeForError(err error) {
 	cc.mu.Lock()
 	cc.closed = true
-
-	var reqBodiesToClose []io.ReadCloser
 	for _, cs := range cc.streams {
-		reqBody := cs.abortStreamLocked(err)
-		if reqBody != nil {
-			reqBodiesToClose = append(reqBodiesToClose, reqBody)
-		}
+		cs.abortStreamLocked(err)
 	}
 	cc.cond.Broadcast()
 	cc.mu.Unlock()
-	for _, reqBody := range reqBodiesToClose {
-		reqBody.Close()
-	}
 	cc.closeConn()
 }
 
@@ -1458,11 +1443,19 @@
 	// and in multiple cases: server replies <=299 and >299
 	// while still writing request body
 	cc.mu.Lock()
+	mustCloseBody := false
+	if cs.reqBody != nil && cs.reqBodyClosed == nil {
+		mustCloseBody = true
+		cs.reqBodyClosed = make(chan struct{})
+	}
 	bodyClosed := cs.reqBodyClosed
-	cs.reqBodyClosed = true
 	cc.mu.Unlock()
-	if !bodyClosed && cs.reqBody != nil {
+	if mustCloseBody {
 		cs.reqBody.Close()
+		close(bodyClosed)
+	}
+	if bodyClosed != nil {
+		<-bodyClosed
 	}
 
 	if err != nil && cs.sentEndStream {
@@ -1642,7 +1635,7 @@
 		}
 		if err != nil {
 			cc.mu.Lock()
-			bodyClosed := cs.reqBodyClosed
+			bodyClosed := cs.reqBodyClosed != nil
 			cc.mu.Unlock()
 			switch {
 			case bodyClosed:
@@ -1737,7 +1730,7 @@
 		if cc.closed {
 			return 0, errClientConnClosed
 		}
-		if cs.reqBodyClosed {
+		if cs.reqBodyClosed != nil {
 			return 0, errStopReqBodyWrite
 		}
 		select {
@@ -2110,24 +2103,17 @@
 	}
 	cc.closed = true
 
-	var reqBodiesToClose []io.ReadCloser
 	for _, cs := range cc.streams {
 		select {
 		case <-cs.peerClosed:
 			// The server closed the stream before closing the conn,
 			// so no need to interrupt it.
 		default:
-			reqBody := cs.abortStreamLocked(err)
-			if reqBody != nil {
-				reqBodiesToClose = append(reqBodiesToClose, reqBody)
-			}
+			cs.abortStreamLocked(err)
 		}
 	}
 	cc.cond.Broadcast()
 	cc.mu.Unlock()
-	for _, reqBody := range reqBodiesToClose {
-		reqBody.Close()
-	}
 }
 
 // countReadFrameError calls Transport.CountError with a string