http2: fix Transport connection pool TOCTOU max concurrent stream bug

Change-Id: I3e02072403f2f40ade4ef931058bbb5892776754
Reviewed-on: https://go-review.googlesource.com/c/net/+/352469
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
Trust: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/client_conn_pool.go b/http2/client_conn_pool.go
index 652bc11..8fd95bb 100644
--- a/http2/client_conn_pool.go
+++ b/http2/client_conn_pool.go
@@ -16,6 +16,12 @@
 
 // ClientConnPool manages a pool of HTTP/2 client connections.
 type ClientConnPool interface {
+	// GetClientConn returns a specific HTTP/2 connection (usually
+	// a TLS-TCP connection) to an HTTP/2 server. On success, the
+	// returned ClientConn accounts for the upcoming RoundTrip
+	// call, so the caller should not omit it. If the caller needs
+	// to, ClientConn.RoundTrip can be called with a bogus
+	// new(http.Request) to release the stream reservation.
 	GetClientConn(req *http.Request, addr string) (*ClientConn, error)
 	MarkDead(*ClientConn)
 }
@@ -61,7 +67,7 @@
 // during the back-and-forth between net/http and x/net/http2 (when the
 // net/http.Transport is upgraded to also speak http2), as well as support
 // the case where x/net/http2 is being used directly.
-func (p *clientConnPool) shouldTraceGetConn(st clientConnIdleState) bool {
+func (p *clientConnPool) shouldTraceGetConn(cc *ClientConn) bool {
 	// If our Transport wasn't made via ConfigureTransport, always
 	// trace the GetConn hook if provided, because that means the
 	// http2 package is being used directly and it's the one
@@ -72,7 +78,9 @@
 	// Otherwise, only use the GetConn hook if this connection has
 	// been used previously for other requests. For fresh
 	// connections, the net/http package does the dialing.
-	return !st.freshConn
+	cc.mu.Lock()
+	defer cc.mu.Unlock()
+	return cc.nextStreamID == 1
 }
 
 func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) {
@@ -89,8 +97,8 @@
 	for {
 		p.mu.Lock()
 		for _, cc := range p.conns[addr] {
-			if st := cc.idleState(); st.canTakeNewRequest {
-				if p.shouldTraceGetConn(st) {
+			if cc.ReserveNewRequest() {
+				if p.shouldTraceGetConn(cc) {
 					traceGetConn(req, addr)
 				}
 				p.mu.Unlock()
@@ -108,7 +116,13 @@
 		if shouldRetryDial(call, req) {
 			continue
 		}
-		return call.res, call.err
+		cc, err := call.res, call.err
+		if err != nil {
+			return nil, err
+		}
+		if cc.ReserveNewRequest() {
+			return cc, nil
+		}
 	}
 }
 
diff --git a/http2/transport.go b/http2/transport.go
index a7f113b..1fb565f 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -267,6 +267,7 @@
 	goAway          *GoAwayFrame             // if non-nil, the GoAwayFrame we received
 	goAwayDebug     string                   // goAway frame's debug data, retained as a string
 	streams         map[uint32]*clientStream // client-initiated
+	streamsReserved int                      // incr by ReserveNewRequest; decr on RoundTrip
 	nextStreamID    uint32
 	pendingRequests int                       // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams
 	pings           map[[8]byte]chan struct{} // in flight ping data to notification channel
@@ -784,12 +785,28 @@
 
 // CanTakeNewRequest reports whether the connection can take a new request,
 // meaning it has not been closed or received or sent a GOAWAY.
+//
+// If the caller is going to immediately make a new request on this
+// connection, use ReserveNewRequest instead.
 func (cc *ClientConn) CanTakeNewRequest() bool {
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
 	return cc.canTakeNewRequestLocked()
 }
 
+// ReserveNewRequest is like CanTakeNewRequest but also reserves a
+// concurrent stream in cc. The reservation is decremented on the
+// next call to RoundTrip.
+func (cc *ClientConn) ReserveNewRequest() bool {
+	cc.mu.Lock()
+	defer cc.mu.Unlock()
+	if st := cc.idleStateLocked(); !st.canTakeNewRequest {
+		return false
+	}
+	cc.streamsReserved++
+	return true
+}
+
 // clientConnIdleState describes the suitability of a client
 // connection to initiate a new RoundTrip request.
 type clientConnIdleState struct {
@@ -815,7 +832,7 @@
 		// writing it.
 		maxConcurrentOkay = true
 	} else {
-		maxConcurrentOkay = int64(len(cc.streams)+1) <= int64(cc.maxConcurrentStreams)
+		maxConcurrentOkay = int64(len(cc.streams)+cc.streamsReserved+1) <= int64(cc.maxConcurrentStreams)
 	}
 
 	st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay &&
@@ -1038,6 +1055,18 @@
 	return -1
 }
 
+func (cc *ClientConn) decrStreamReservations() {
+	cc.mu.Lock()
+	defer cc.mu.Unlock()
+	cc.decrStreamReservationsLocked()
+}
+
+func (cc *ClientConn) decrStreamReservationsLocked() {
+	if cc.streamsReserved > 0 {
+		cc.streamsReserved--
+	}
+}
+
 func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
 	resp, _, err := cc.roundTrip(req)
 	return resp, err
@@ -1046,6 +1075,7 @@
 func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAfterReqBodyWrite bool, err error) {
 	ctx := req.Context()
 	if err := checkConnHeaders(req); err != nil {
+		cc.decrStreamReservations()
 		return nil, false, err
 	}
 	if cc.idleTimer != nil {
@@ -1054,6 +1084,7 @@
 
 	trailers, err := commaSeparatedTrailers(req)
 	if err != nil {
+		cc.decrStreamReservations()
 		return nil, false, err
 	}
 	hasTrailers := trailers != ""
@@ -1067,8 +1098,10 @@
 	select {
 	case cc.reqHeaderMu <- struct{}{}:
 	case <-req.Cancel:
+		cc.decrStreamReservations()
 		return nil, false, errRequestCanceled
 	case <-ctx.Done():
+		cc.decrStreamReservations()
 		return nil, false, ctx.Err()
 	}
 	reqHeaderMuNeedsUnlock := true
@@ -1079,6 +1112,11 @@
 	}()
 
 	cc.mu.Lock()
+	cc.decrStreamReservationsLocked()
+	if req.URL == nil {
+		cc.mu.Unlock()
+		return nil, false, errNilRequestURL
+	}
 	if err := cc.awaitOpenSlotForRequest(req); err != nil {
 		cc.mu.Unlock()
 		return nil, false, err
@@ -1531,9 +1569,14 @@
 	}
 }
 
+var errNilRequestURL = errors.New("http2: Request.URI is nil")
+
 // requires cc.wmu be held.
 func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
 	cc.hbuf.Reset()
+	if req.URL == nil {
+		return nil, errNilRequestURL
+	}
 
 	host := req.Host
 	if host == "" {
diff --git a/http2/transport_test.go b/http2/transport_test.go
index dd0860d..3eb133f 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -5422,3 +5422,42 @@
 	}
 	ct.run()
 }
+
+func TestClientConnReservations(t *testing.T) {
+	cc := &ClientConn{
+		reqHeaderMu:          make(chan struct{}, 1),
+		streams:              make(map[uint32]*clientStream),
+		maxConcurrentStreams: initialMaxConcurrentStreams,
+		t:                    &Transport{},
+	}
+	n := 0
+	for n <= initialMaxConcurrentStreams && cc.ReserveNewRequest() {
+		n++
+	}
+	if n != initialMaxConcurrentStreams {
+		t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams)
+	}
+	if _, err := cc.RoundTrip(new(http.Request)); !errors.Is(err, errNilRequestURL) {
+		t.Fatalf("RoundTrip error = %v; want errNilRequestURL", err)
+	}
+	n2 := 0
+	for n2 <= 5 && cc.ReserveNewRequest() {
+		n2++
+	}
+	if n2 != 1 {
+		t.Fatalf("after one RoundTrip, did %v reservations; want 1", n2)
+	}
+
+	// Use up all the reservations
+	for i := 0; i < n; i++ {
+		cc.RoundTrip(new(http.Request))
+	}
+
+	n2 = 0
+	for n2 <= initialMaxConcurrentStreams && cc.ReserveNewRequest() {
+		n2++
+	}
+	if n2 != n {
+		t.Errorf("after reset, reservations = %v; want %v", n2, n)
+	}
+}