http2: fix Transport connection pool TOCTOU max concurrent stream bug
Change-Id: I3e02072403f2f40ade4ef931058bbb5892776754
Reviewed-on: https://go-review.googlesource.com/c/net/+/352469
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
Trust: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/client_conn_pool.go b/http2/client_conn_pool.go
index 652bc11..8fd95bb 100644
--- a/http2/client_conn_pool.go
+++ b/http2/client_conn_pool.go
@@ -16,6 +16,12 @@
// ClientConnPool manages a pool of HTTP/2 client connections.
type ClientConnPool interface {
+ // GetClientConn returns a specific HTTP/2 connection (usually
+ // a TLS-TCP connection) to an HTTP/2 server. On success, the
+ // returned ClientConn accounts for the upcoming RoundTrip
+ // call, so the caller should not omit it. If the caller needs
+ // to, ClientConn.RoundTrip can be called with a bogus
+ // new(http.Request) to release the stream reservation.
GetClientConn(req *http.Request, addr string) (*ClientConn, error)
MarkDead(*ClientConn)
}
@@ -61,7 +67,7 @@
// during the back-and-forth between net/http and x/net/http2 (when the
// net/http.Transport is upgraded to also speak http2), as well as support
// the case where x/net/http2 is being used directly.
-func (p *clientConnPool) shouldTraceGetConn(st clientConnIdleState) bool {
+func (p *clientConnPool) shouldTraceGetConn(cc *ClientConn) bool {
// If our Transport wasn't made via ConfigureTransport, always
// trace the GetConn hook if provided, because that means the
// http2 package is being used directly and it's the one
@@ -72,7 +78,9 @@
// Otherwise, only use the GetConn hook if this connection has
// been used previously for other requests. For fresh
// connections, the net/http package does the dialing.
- return !st.freshConn
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ return cc.nextStreamID == 1
}
func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*ClientConn, error) {
@@ -89,8 +97,8 @@
for {
p.mu.Lock()
for _, cc := range p.conns[addr] {
- if st := cc.idleState(); st.canTakeNewRequest {
- if p.shouldTraceGetConn(st) {
+ if cc.ReserveNewRequest() {
+ if p.shouldTraceGetConn(cc) {
traceGetConn(req, addr)
}
p.mu.Unlock()
@@ -108,7 +116,13 @@
if shouldRetryDial(call, req) {
continue
}
- return call.res, call.err
+ cc, err := call.res, call.err
+ if err != nil {
+ return nil, err
+ }
+ if cc.ReserveNewRequest() {
+ return cc, nil
+ }
}
}
diff --git a/http2/transport.go b/http2/transport.go
index a7f113b..1fb565f 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -267,6 +267,7 @@
goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received
goAwayDebug string // goAway frame's debug data, retained as a string
streams map[uint32]*clientStream // client-initiated
+ streamsReserved int // incr by ReserveNewRequest; decr on RoundTrip
nextStreamID uint32
pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams
pings map[[8]byte]chan struct{} // in flight ping data to notification channel
@@ -784,12 +785,28 @@
// CanTakeNewRequest reports whether the connection can take a new request,
// meaning it has not been closed or received or sent a GOAWAY.
+//
+// If the caller is going to immediately make a new request on this
+// connection, use ReserveNewRequest instead.
func (cc *ClientConn) CanTakeNewRequest() bool {
cc.mu.Lock()
defer cc.mu.Unlock()
return cc.canTakeNewRequestLocked()
}
+// ReserveNewRequest is like CanTakeNewRequest but also reserves a
+// concurrent stream in cc. The reservation is decremented on the
+// next call to RoundTrip.
+func (cc *ClientConn) ReserveNewRequest() bool {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ if st := cc.idleStateLocked(); !st.canTakeNewRequest {
+ return false
+ }
+ cc.streamsReserved++
+ return true
+}
+
// clientConnIdleState describes the suitability of a client
// connection to initiate a new RoundTrip request.
type clientConnIdleState struct {
@@ -815,7 +832,7 @@
// writing it.
maxConcurrentOkay = true
} else {
- maxConcurrentOkay = int64(len(cc.streams)+1) <= int64(cc.maxConcurrentStreams)
+ maxConcurrentOkay = int64(len(cc.streams)+cc.streamsReserved+1) <= int64(cc.maxConcurrentStreams)
}
st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay &&
@@ -1038,6 +1055,18 @@
return -1
}
+func (cc *ClientConn) decrStreamReservations() {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ cc.decrStreamReservationsLocked()
+}
+
+func (cc *ClientConn) decrStreamReservationsLocked() {
+ if cc.streamsReserved > 0 {
+ cc.streamsReserved--
+ }
+}
+
func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
resp, _, err := cc.roundTrip(req)
return resp, err
@@ -1046,6 +1075,7 @@
func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAfterReqBodyWrite bool, err error) {
ctx := req.Context()
if err := checkConnHeaders(req); err != nil {
+ cc.decrStreamReservations()
return nil, false, err
}
if cc.idleTimer != nil {
@@ -1054,6 +1084,7 @@
trailers, err := commaSeparatedTrailers(req)
if err != nil {
+ cc.decrStreamReservations()
return nil, false, err
}
hasTrailers := trailers != ""
@@ -1067,8 +1098,10 @@
select {
case cc.reqHeaderMu <- struct{}{}:
case <-req.Cancel:
+ cc.decrStreamReservations()
return nil, false, errRequestCanceled
case <-ctx.Done():
+ cc.decrStreamReservations()
return nil, false, ctx.Err()
}
reqHeaderMuNeedsUnlock := true
@@ -1079,6 +1112,11 @@
}()
cc.mu.Lock()
+ cc.decrStreamReservationsLocked()
+ if req.URL == nil {
+ cc.mu.Unlock()
+ return nil, false, errNilRequestURL
+ }
if err := cc.awaitOpenSlotForRequest(req); err != nil {
cc.mu.Unlock()
return nil, false, err
@@ -1531,9 +1569,14 @@
}
}
+var errNilRequestURL = errors.New("http2: Request.URI is nil")
+
// requires cc.wmu be held.
func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
cc.hbuf.Reset()
+ if req.URL == nil {
+ return nil, errNilRequestURL
+ }
host := req.Host
if host == "" {
diff --git a/http2/transport_test.go b/http2/transport_test.go
index dd0860d..3eb133f 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -5422,3 +5422,42 @@
}
ct.run()
}
+
+func TestClientConnReservations(t *testing.T) {
+ cc := &ClientConn{
+ reqHeaderMu: make(chan struct{}, 1),
+ streams: make(map[uint32]*clientStream),
+ maxConcurrentStreams: initialMaxConcurrentStreams,
+ t: &Transport{},
+ }
+ n := 0
+ for n <= initialMaxConcurrentStreams && cc.ReserveNewRequest() {
+ n++
+ }
+ if n != initialMaxConcurrentStreams {
+ t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams)
+ }
+ if _, err := cc.RoundTrip(new(http.Request)); !errors.Is(err, errNilRequestURL) {
+ t.Fatalf("RoundTrip error = %v; want errNilRequestURL", err)
+ }
+ n2 := 0
+ for n2 <= 5 && cc.ReserveNewRequest() {
+ n2++
+ }
+ if n2 != 1 {
+ t.Fatalf("after one RoundTrip, did %v reservations; want 1", n2)
+ }
+
+ // Use up all the reservations
+ for i := 0; i < n; i++ {
+ cc.RoundTrip(new(http.Request))
+ }
+
+ n2 = 0
+ for n2 <= initialMaxConcurrentStreams && cc.ReserveNewRequest() {
+ n2++
+ }
+ if n2 != n {
+ t.Errorf("after reset, reservations = %v; want %v", n2, n)
+ }
+}