http2: block RoundTrip when the Transport hits MaxConcurrentStreams
Currently if the http2.Transport hits SettingsMaxConcurrentStreams for a
server, it just makes a new TCP connection and creates the stream on the
new connection. This CL updates that behavior to instead block RoundTrip
until a new stream is available.
I also fixed a second bug, which was necessary to make some tests pass:
Previously, a stream was removed from cc.streams only if either (a) we
received END_STREAM from the server, or (b) we received RST_STREAM from
the server. This CL removes a stream from cc.streams if the request was
cancelled (via ctx.Close, req.Cancel, or resp.Body.Close) before
receiving END_STREAM or RST_STREAM from the server.
Updates golang/go#13774
Updates golang/go#20985
Updates golang/go#21229
Change-Id: I660ffd724c4c513e0f1cc587b404bedb8aff80be
Reviewed-on: https://go-review.googlesource.com/53250
Run-TryBot: Tom Bergan <tombergan@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 39a1a46..e0dfe9f 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -165,6 +165,7 @@
goAwayDebug string // goAway frame's debug data, retained as a string
streams map[uint32]*clientStream // client-initiated
nextStreamID uint32
+ pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams
pings map[[8]byte]chan struct{} // in flight ping data to notification channel
bw *bufio.Writer
br *bufio.Reader
@@ -217,35 +218,45 @@
resTrailer *http.Header // client's Response.Trailer
}
-// awaitRequestCancel runs in its own goroutine and waits for the user
-// to cancel a RoundTrip request, its context to expire, or for the
-// request to be done (any way it might be removed from the cc.streams
-// map: peer reset, successful completion, TCP connection breakage,
-// etc)
-func (cs *clientStream) awaitRequestCancel(req *http.Request) {
+// awaitRequestCancel waits for the user to cancel a request or for the done
+// channel to be signaled. A non-nil error is returned only if the request was
+// canceled.
+func awaitRequestCancel(req *http.Request, done <-chan struct{}) error {
ctx := reqContext(req)
if req.Cancel == nil && ctx.Done() == nil {
- return
+ return nil
}
select {
case <-req.Cancel:
- cs.cancelStream()
- cs.bufPipe.CloseWithError(errRequestCanceled)
+ return errRequestCanceled
case <-ctx.Done():
+ return ctx.Err()
+ case <-done:
+ return nil
+ }
+}
+
+// awaitRequestCancel waits for the user to cancel a request, its context to
+// expire, or for the request to be done (any way it might be removed from the
+// cc.streams map: peer reset, successful completion, TCP connection breakage,
+// etc). If the request is canceled, then cs will be canceled and closed.
+func (cs *clientStream) awaitRequestCancel(req *http.Request) {
+ if err := awaitRequestCancel(req, cs.done); err != nil {
cs.cancelStream()
- cs.bufPipe.CloseWithError(ctx.Err())
- case <-cs.done:
+ cs.bufPipe.CloseWithError(err)
}
}
func (cs *clientStream) cancelStream() {
- cs.cc.mu.Lock()
+ cc := cs.cc
+ cc.mu.Lock()
didReset := cs.didReset
cs.didReset = true
- cs.cc.mu.Unlock()
+ cc.mu.Unlock()
if !didReset {
- cs.cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
+ cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
+ cc.forgetStreamID(cs.ID)
}
}
@@ -594,6 +605,8 @@
}
}
+// CanTakeNewRequest reports whether the connection can take a new request,
+// meaning it has not been closed or received or sent a GOAWAY.
func (cc *ClientConn) CanTakeNewRequest() bool {
cc.mu.Lock()
defer cc.mu.Unlock()
@@ -605,8 +618,7 @@
return false
}
return cc.goAway == nil && !cc.closed &&
- int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) &&
- cc.nextStreamID < math.MaxInt32
+ int64(cc.nextStreamID)+int64(cc.pendingRequests) < math.MaxInt32
}
// onIdleTimeout is called from a time.AfterFunc goroutine. It will
@@ -752,10 +764,9 @@
hasTrailers := trailers != ""
cc.mu.Lock()
- cc.lastActive = time.Now()
- if cc.closed || !cc.canTakeNewRequestLocked() {
+ if err := cc.awaitOpenSlotForRequest(req); err != nil {
cc.mu.Unlock()
- return nil, errClientConnUnusable
+ return nil, err
}
body := req.Body
@@ -869,31 +880,31 @@
case re := <-readLoopResCh:
return handleReadLoopResponse(re)
case <-respHeaderTimer:
- cc.forgetStreamID(cs.ID)
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
} else {
bodyWriter.cancel()
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
}
+ cc.forgetStreamID(cs.ID)
return nil, errTimeout
case <-ctx.Done():
- cc.forgetStreamID(cs.ID)
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
} else {
bodyWriter.cancel()
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
}
+ cc.forgetStreamID(cs.ID)
return nil, ctx.Err()
case <-req.Cancel:
- cc.forgetStreamID(cs.ID)
if !hasBody || bodyWritten {
cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
} else {
bodyWriter.cancel()
cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
}
+ cc.forgetStreamID(cs.ID)
return nil, errRequestCanceled
case <-cs.peerReset:
// processResetStream already removed the
@@ -920,6 +931,45 @@
}
}
+// awaitOpenSlotForRequest waits until len(streams) < maxConcurrentStreams.
+// Must hold cc.mu.
+func (cc *ClientConn) awaitOpenSlotForRequest(req *http.Request) error {
+ var waitingForConn chan struct{}
+ var waitingForConnErr error // guarded by cc.mu
+ for {
+ cc.lastActive = time.Now()
+ if cc.closed || !cc.canTakeNewRequestLocked() {
+ return errClientConnUnusable
+ }
+ if int64(len(cc.streams))+1 <= int64(cc.maxConcurrentStreams) {
+ if waitingForConn != nil {
+ close(waitingForConn)
+ }
+ return nil
+ }
+ // Unfortunately, we cannot wait on a condition variable and channel at
+ // the same time, so instead, we spin up a goroutine to check if the
+ // request is canceled while we wait for a slot to open in the connection.
+ if waitingForConn == nil {
+ waitingForConn = make(chan struct{})
+ go func() {
+ if err := awaitRequestCancel(req, waitingForConn); err != nil {
+ cc.mu.Lock()
+ waitingForConnErr = err
+ cc.cond.Broadcast()
+ cc.mu.Unlock()
+ }
+ }()
+ }
+ cc.pendingRequests++
+ cc.cond.Wait()
+ cc.pendingRequests--
+ if waitingForConnErr != nil {
+ return waitingForConnErr
+ }
+ }
+}
+
// requires cc.wmu be held
func (cc *ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs []byte) error {
first := true // first frame written (HEADERS is first, then CONTINUATION)
@@ -1279,7 +1329,9 @@
cc.idleTimer.Reset(cc.idleTimeout)
}
close(cs.done)
- cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl
+ // Wake up checkResetOrDone via clientStream.awaitFlowControl and
+ // wake up RoundTrip if there is a pending request.
+ cc.cond.Broadcast()
}
return cs
}
@@ -1378,8 +1430,9 @@
cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err)
}
if se, ok := err.(StreamError); ok {
- if cs := cc.streamByID(se.StreamID, true /*ended; remove it*/); cs != nil {
+ if cs := cc.streamByID(se.StreamID, false); cs != nil {
cs.cc.writeStreamReset(cs.ID, se.Code, err)
+ cs.cc.forgetStreamID(cs.ID)
if se.Cause == nil {
se.Cause = cc.fr.errDetail
}
@@ -1701,6 +1754,7 @@
}
cs.bufPipe.BreakWithError(errClosedResponseBody)
+ cc.forgetStreamID(cs.ID)
return nil
}
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 0e7b801..ac4661f 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -685,7 +685,7 @@
return ln
}
-func (ct *clientTester) greet() {
+func (ct *clientTester) greet(settings ...Setting) {
buf := make([]byte, len(ClientPreface))
_, err := io.ReadFull(ct.sc, buf)
if err != nil {
@@ -699,7 +699,7 @@
ct.t.Fatalf("Wanted client settings frame; got %v", f)
_ = sf // stash it away?
}
- if err := ct.fr.WriteSettings(); err != nil {
+ if err := ct.fr.WriteSettings(settings...); err != nil {
ct.t.Fatal(err)
}
if err := ct.fr.WriteSettingsAck(); err != nil {
@@ -3036,6 +3036,175 @@
ct.run()
}
+func TestTransportRequestsStallAtServerLimit(t *testing.T) {
+ const maxConcurrent = 2
+
+ greet := make(chan struct{}) // server sends initial SETTINGS frame
+ gotRequest := make(chan struct{}) // server received a request
+ clientDone := make(chan struct{})
+
+ // Collect errors from goroutines.
+ var wg sync.WaitGroup
+ errs := make(chan error, 100)
+ defer func() {
+ wg.Wait()
+ close(errs)
+ for err := range errs {
+ t.Error(err)
+ }
+ }()
+
+ // We will send maxConcurrent+2 requests. This checker goroutine waits for the
+ // following stages:
+ // 1. The first maxConcurrent requests are received by the server.
+ // 2. The client will cancel the next request
+ // 3. The server is unblocked so it can service the first maxConcurrent requests
+ // 4. The client will send the final request
+ wg.Add(1)
+ unblockClient := make(chan struct{})
+ clientRequestCancelled := make(chan struct{})
+ unblockServer := make(chan struct{})
+ go func() {
+ defer wg.Done()
+ // Stage 1.
+ for k := 0; k < maxConcurrent; k++ {
+ <-gotRequest
+ }
+ // Stage 2.
+ close(unblockClient)
+ <-clientRequestCancelled
+ // Stage 3: give some time for the final RoundTrip call to be scheduled and
+ // verify that the final request is not sent.
+ time.Sleep(50 * time.Millisecond)
+ select {
+ case <-gotRequest:
+ errs <- errors.New("last request did not stall")
+ close(unblockServer)
+ return
+ default:
+ }
+ close(unblockServer)
+ // Stage 4.
+ <-gotRequest
+ }()
+
+ ct := newClientTester(t)
+ ct.client = func() error {
+ var wg sync.WaitGroup
+ defer func() {
+ wg.Wait()
+ close(clientDone)
+ ct.cc.(*net.TCPConn).CloseWrite()
+ }()
+ for k := 0; k < maxConcurrent+2; k++ {
+ wg.Add(1)
+ go func(k int) {
+ defer wg.Done()
+ // Don't send the second request until after receiving SETTINGS from the server
+ // to avoid a race where we use the default SettingMaxConcurrentStreams, which
+ // is much larger than maxConcurrent. We have to send the first request before
+ // waiting because the first request triggers the dial and greet.
+ if k > 0 {
+ <-greet
+ }
+ // Block until maxConcurrent requests are sent before sending any more.
+ if k >= maxConcurrent {
+ <-unblockClient
+ }
+ req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), nil)
+ if k == maxConcurrent {
+ // This request will be canceled.
+ cancel := make(chan struct{})
+ req.Cancel = cancel
+ close(cancel)
+ _, err := ct.tr.RoundTrip(req)
+ close(clientRequestCancelled)
+ if err == nil {
+ errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k)
+ return
+ }
+ } else {
+ resp, err := ct.tr.RoundTrip(req)
+ if err != nil {
+ errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
+ return
+ }
+ ioutil.ReadAll(resp.Body)
+ resp.Body.Close()
+ if resp.StatusCode != 204 {
+ errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode)
+ return
+ }
+ }
+ }(k)
+ }
+ return nil
+ }
+
+ ct.server = func() error {
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})
+
+ // Server write loop.
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ writeResp := make(chan uint32, maxConcurrent+1)
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ <-unblockServer
+ for id := range writeResp {
+ buf.Reset()
+ enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
+ ct.fr.WriteHeaders(HeadersFrameParam{
+ StreamID: id,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: buf.Bytes(),
+ })
+ }
+ }()
+
+ // Server read loop.
+ var nreq int
+ for {
+ f, err := ct.fr.ReadFrame()
+ if err != nil {
+ select {
+ case <-clientDone:
+ // If the client's done, it will have reported any errors on its side.
+ return nil
+ default:
+ return err
+ }
+ }
+ switch f := f.(type) {
+ case *WindowUpdateFrame:
+ case *SettingsFrame:
+ // Wait for the client SETTINGS ack until ending the greet.
+ close(greet)
+ case *HeadersFrame:
+ if !f.HeadersEnded() {
+ return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
+ }
+ gotRequest <- struct{}{}
+ nreq++
+ writeResp <- f.StreamID
+ if nreq == maxConcurrent+1 {
+ close(writeResp)
+ }
+ default:
+ return fmt.Errorf("Unexpected client frame %v", f)
+ }
+ }
+ }
+
+ ct.run()
+}
+
func TestAuthorityAddr(t *testing.T) {
tests := []struct {
scheme, authority string