http2: block RoundTrip when the Transport hits MaxConcurrentStreams

Currently, if the http2.Transport hits SettingsMaxConcurrentStreams for a
server, it just makes a new TCP connection and creates the stream on the
new connection. This CL changes that behavior: RoundTrip now blocks until
a slot for a new stream opens up on the existing connection.
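
To make the new behavior concrete, here is a rough sketch (the server URL
is hypothetical and not part of this CL): against a server that advertises
SETTINGS_MAX_CONCURRENT_STREAMS=1, the second of two concurrent RoundTrip
calls used to ride on a second TCP connection; with this CL it blocks
inside RoundTrip until the first stream completes.

	package main

	import (
		"fmt"
		"net/http"
		"sync"

		"golang.org/x/net/http2"
	)

	func main() {
		tr := &http2.Transport{}
		var wg sync.WaitGroup
		for i := 0; i < 2; i++ {
			wg.Add(1)
			go func(i int) {
				defer wg.Done()
				req, err := http.NewRequest("GET", "https://example.com/", nil)
				if err != nil {
					return
				}
				// Before this CL: both calls proceed at once, the second
				// on a fresh TCP connection. After: the second call blocks
				// inside RoundTrip until the first stream is done.
				resp, err := tr.RoundTrip(req)
				if err != nil {
					fmt.Println(i, err)
					return
				}
				resp.Body.Close()
			}(i)
		}
		wg.Wait()
	}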

I also fixed a second bug, which was necessary to make some tests pass:
Previously, a stream was removed from cc.streams only if either (a) we
received END_STREAM from the server, or (b) we received RST_STREAM from
the server. This CL removes a stream from cc.streams if the request was
canceled (via context cancellation, req.Cancel, or resp.Body.Close) before
receiving END_STREAM or RST_STREAM from the server.
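
A sketch of the effect (hypothetical URL; tr as in the sketch above;
assumes the context and time imports and Go 1.7+ for req.WithContext):
canceling a request before the server answers now frees its slot in
cc.streams, so a RoundTrip blocked on MaxConcurrentStreams can take it.

	func cancelExample(tr *http2.Transport) error {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		req, err := http.NewRequest("GET", "https://example.com/slow", nil)
		if err != nil {
			return err
		}
		req = req.WithContext(ctx)
		go func() {
			time.Sleep(100 * time.Millisecond)
			cancel() // give up before the server responds
		}()
		// RoundTrip returns ctx.Err() once the cancellation is observed.
		// With this CL the canceled stream is also removed from cc.streams,
		// freeing its slot for any pending request.
		_, err = tr.RoundTrip(req)
		return err
	}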

Updates golang/go#13774
Updates golang/go#20985
Updates golang/go#21229

Change-Id: I660ffd724c4c513e0f1cc587b404bedb8aff80be
Reviewed-on: https://go-review.googlesource.com/53250
Run-TryBot: Tom Bergan <tombergan@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 39a1a46..e0dfe9f 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -165,6 +165,7 @@
 	goAwayDebug     string                   // goAway frame's debug data, retained as a string
 	streams         map[uint32]*clientStream // client-initiated
 	nextStreamID    uint32
+	pendingRequests int                       // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams
 	pings           map[[8]byte]chan struct{} // in flight ping data to notification channel
 	bw              *bufio.Writer
 	br              *bufio.Reader
@@ -217,35 +218,45 @@
 	resTrailer *http.Header // client's Response.Trailer
 }
 
-// awaitRequestCancel runs in its own goroutine and waits for the user
-// to cancel a RoundTrip request, its context to expire, or for the
-// request to be done (any way it might be removed from the cc.streams
-// map: peer reset, successful completion, TCP connection breakage,
-// etc)
-func (cs *clientStream) awaitRequestCancel(req *http.Request) {
+// awaitRequestCancel waits for the user to cancel a request or for the done
+// channel to be signaled. A non-nil error is returned only if the request was
+// canceled.
+func awaitRequestCancel(req *http.Request, done <-chan struct{}) error {
 	ctx := reqContext(req)
 	if req.Cancel == nil && ctx.Done() == nil {
-		return
+		return nil
 	}
 	select {
 	case <-req.Cancel:
-		cs.cancelStream()
-		cs.bufPipe.CloseWithError(errRequestCanceled)
+		return errRequestCanceled
 	case <-ctx.Done():
+		return ctx.Err()
+	case <-done:
+		return nil
+	}
+}
+
+// awaitRequestCancel waits for the user to cancel a request, its context to
+// expire, or for the request to be done (any way it might be removed from the
+// cc.streams map: peer reset, successful completion, TCP connection breakage,
+// etc). If the request is canceled, then cs will be canceled and closed.
+func (cs *clientStream) awaitRequestCancel(req *http.Request) {
+	if err := awaitRequestCancel(req, cs.done); err != nil {
 		cs.cancelStream()
-		cs.bufPipe.CloseWithError(ctx.Err())
-	case <-cs.done:
+		cs.bufPipe.CloseWithError(err)
 	}
 }
 
 func (cs *clientStream) cancelStream() {
-	cs.cc.mu.Lock()
+	cc := cs.cc
+	cc.mu.Lock()
 	didReset := cs.didReset
 	cs.didReset = true
-	cs.cc.mu.Unlock()
+	cc.mu.Unlock()
 
 	if !didReset {
-		cs.cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
+		cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
+		cc.forgetStreamID(cs.ID)
 	}
 }
 
@@ -594,6 +605,8 @@
 	}
 }
 
+// CanTakeNewRequest reports whether the connection can take a new request,
+// meaning it has not been closed and has not received or sent a GOAWAY.
 func (cc *ClientConn) CanTakeNewRequest() bool {
 	cc.mu.Lock()
 	defer cc.mu.Unlock()
@@ -605,8 +618,7 @@
 		return false
 	}
 	return cc.goAway == nil && !cc.closed &&
-		int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) &&
-		cc.nextStreamID < math.MaxInt32
+		int64(cc.nextStreamID)+int64(cc.pendingRequests) < math.MaxInt32
 }
 
 // onIdleTimeout is called from a time.AfterFunc goroutine. It will
@@ -752,10 +764,9 @@
 	hasTrailers := trailers != ""
 
 	cc.mu.Lock()
-	cc.lastActive = time.Now()
-	if cc.closed || !cc.canTakeNewRequestLocked() {
+	if err := cc.awaitOpenSlotForRequest(req); err != nil {
 		cc.mu.Unlock()
-		return nil, errClientConnUnusable
+		return nil, err
 	}
 
 	body := req.Body
@@ -869,31 +880,31 @@
 		case re := <-readLoopResCh:
 			return handleReadLoopResponse(re)
 		case <-respHeaderTimer:
-			cc.forgetStreamID(cs.ID)
 			if !hasBody || bodyWritten {
 				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
 			} else {
 				bodyWriter.cancel()
 				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
 			}
+			cc.forgetStreamID(cs.ID)
 			return nil, errTimeout
 		case <-ctx.Done():
-			cc.forgetStreamID(cs.ID)
 			if !hasBody || bodyWritten {
 				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
 			} else {
 				bodyWriter.cancel()
 				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
 			}
+			cc.forgetStreamID(cs.ID)
 			return nil, ctx.Err()
 		case <-req.Cancel:
-			cc.forgetStreamID(cs.ID)
 			if !hasBody || bodyWritten {
 				cc.writeStreamReset(cs.ID, ErrCodeCancel, nil)
 			} else {
 				bodyWriter.cancel()
 				cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel)
 			}
+			cc.forgetStreamID(cs.ID)
 			return nil, errRequestCanceled
 		case <-cs.peerReset:
 			// processResetStream already removed the
@@ -920,6 +931,45 @@
 	}
 }
 
+// awaitOpenSlotForRequest waits until len(streams) < maxConcurrentStreams.
+// Must hold cc.mu.
+func (cc *ClientConn) awaitOpenSlotForRequest(req *http.Request) error {
+	var waitingForConn chan struct{}
+	var waitingForConnErr error // guarded by cc.mu
+	for {
+		cc.lastActive = time.Now()
+		if cc.closed || !cc.canTakeNewRequestLocked() {
+			return errClientConnUnusable
+		}
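+		// When a slot is open, closing waitingForConn (if it was created)
+		// tells the cancellation-watching goroutine below to exit.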
+		if int64(len(cc.streams))+1 <= int64(cc.maxConcurrentStreams) {
+			if waitingForConn != nil {
+				close(waitingForConn)
+			}
+			return nil
+		}
+		// Unfortunately, we cannot wait on a condition variable and channel at
+		// the same time, so instead, we spin up a goroutine to check if the
+		// request is canceled while we wait for a slot to open in the connection.
+		if waitingForConn == nil {
+			waitingForConn = make(chan struct{})
+			go func() {
+				if err := awaitRequestCancel(req, waitingForConn); err != nil {
+					cc.mu.Lock()
+					waitingForConnErr = err
+					cc.cond.Broadcast()
+					cc.mu.Unlock()
+				}
+			}()
+		}
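+		// Park this request. pendingRequests lets canTakeNewRequestLocked
+		// account for blocked requests when checking that the stream ID
+		// space is not about to run out.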
+		cc.pendingRequests++
+		cc.cond.Wait()
+		cc.pendingRequests--
+		if waitingForConnErr != nil {
+			return waitingForConnErr
+		}
+	}
+}
+
 // requires cc.wmu be held
 func (cc *ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs []byte) error {
 	first := true // first frame written (HEADERS is first, then CONTINUATION)
@@ -1279,7 +1329,9 @@
 			cc.idleTimer.Reset(cc.idleTimeout)
 		}
 		close(cs.done)
-		cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl
+		// Wake up checkResetOrDone via clientStream.awaitFlowControl and
+		// wake up RoundTrip if there is a pending request.
+		cc.cond.Broadcast()
 	}
 	return cs
 }
@@ -1378,8 +1430,9 @@
 			cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err)
 		}
 		if se, ok := err.(StreamError); ok {
-			if cs := cc.streamByID(se.StreamID, true /*ended; remove it*/); cs != nil {
+			if cs := cc.streamByID(se.StreamID, false); cs != nil {
 				cs.cc.writeStreamReset(cs.ID, se.Code, err)
+				cs.cc.forgetStreamID(cs.ID)
 				if se.Cause == nil {
 					se.Cause = cc.fr.errDetail
 				}
@@ -1701,6 +1754,7 @@
 	}
 
 	cs.bufPipe.BreakWithError(errClosedResponseBody)
+	cc.forgetStreamID(cs.ID)
 	return nil
 }
 
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 0e7b801..ac4661f 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -685,7 +685,7 @@
 	return ln
 }
 
-func (ct *clientTester) greet() {
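+// greet reads the client preface and initial SETTINGS, then replies with a
+// SETTINGS frame carrying the given settings (if any) and a SETTINGS ack.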
+func (ct *clientTester) greet(settings ...Setting) {
 	buf := make([]byte, len(ClientPreface))
 	_, err := io.ReadFull(ct.sc, buf)
 	if err != nil {
@@ -699,7 +699,7 @@
 		ct.t.Fatalf("Wanted client settings frame; got %v", f)
 		_ = sf // stash it away?
 	}
-	if err := ct.fr.WriteSettings(); err != nil {
+	if err := ct.fr.WriteSettings(settings...); err != nil {
 		ct.t.Fatal(err)
 	}
 	if err := ct.fr.WriteSettingsAck(); err != nil {
@@ -3036,6 +3036,175 @@
 	ct.run()
 }
 
+func TestTransportRequestsStallAtServerLimit(t *testing.T) {
+	const maxConcurrent = 2
+
+	greet := make(chan struct{})      // server sends initial SETTINGS frame
+	gotRequest := make(chan struct{}) // server received a request
+	clientDone := make(chan struct{})
+
+	// Collect errors from goroutines.
+	var wg sync.WaitGroup
+	errs := make(chan error, 100)
+	defer func() {
+		wg.Wait()
+		close(errs)
+		for err := range errs {
+			t.Error(err)
+		}
+	}()
+
+	// We will send maxConcurrent+2 requests. This checker goroutine waits for the
+	// following stages:
+	//   1. The first maxConcurrent requests are received by the server.
+	//   2. The client will cancel the next request.
+	//   3. The server is unblocked so it can service the first maxConcurrent requests.
+	//   4. The client will send the final request.
+	wg.Add(1)
+	unblockClient := make(chan struct{})
+	clientRequestCancelled := make(chan struct{})
+	unblockServer := make(chan struct{})
+	go func() {
+		defer wg.Done()
+		// Stage 1.
+		for k := 0; k < maxConcurrent; k++ {
+			<-gotRequest
+		}
+		// Stage 2.
+		close(unblockClient)
+		<-clientRequestCancelled
+		// Stage 3: give some time for the final RoundTrip call to be scheduled and
+		// verify that the final request is not sent.
+		time.Sleep(50 * time.Millisecond)
+		select {
+		case <-gotRequest:
+			errs <- errors.New("last request did not stall")
+			close(unblockServer)
+			return
+		default:
+		}
+		close(unblockServer)
+		// Stage 4.
+		<-gotRequest
+	}()
+
+	ct := newClientTester(t)
+	ct.client = func() error {
+		var wg sync.WaitGroup
+		defer func() {
+			wg.Wait()
+			close(clientDone)
+			ct.cc.(*net.TCPConn).CloseWrite()
+		}()
+		for k := 0; k < maxConcurrent+2; k++ {
+			wg.Add(1)
+			go func(k int) {
+				defer wg.Done()
+				// Don't send the second request until after receiving SETTINGS from the server
+				// to avoid a race where we use the default SettingMaxConcurrentStreams, which
+				// is much larger than maxConcurrent. We have to send the first request before
+				// waiting because the first request triggers the dial and greet.
+				if k > 0 {
+					<-greet
+				}
+				// Block until maxConcurrent requests are sent before sending any more.
+				if k >= maxConcurrent {
+					<-unblockClient
+				}
+				req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), nil)
+				if k == maxConcurrent {
+					// This request will be canceled.
+					cancel := make(chan struct{})
+					req.Cancel = cancel
+					close(cancel)
+					_, err := ct.tr.RoundTrip(req)
+					close(clientRequestCancelled)
+					if err == nil {
+						errs <- fmt.Errorf("RoundTrip(%d) should have failed due to cancel", k)
+						return
+					}
+				} else {
+					resp, err := ct.tr.RoundTrip(req)
+					if err != nil {
+						errs <- fmt.Errorf("RoundTrip(%d): %v", k, err)
+						return
+					}
+					ioutil.ReadAll(resp.Body)
+					resp.Body.Close()
+					if resp.StatusCode != 204 {
+						errs <- fmt.Errorf("Status = %v; want 204", resp.StatusCode)
+						return
+					}
+				}
+			}(k)
+		}
+		return nil
+	}
+
+	ct.server = func() error {
+		var wg sync.WaitGroup
+		defer wg.Wait()
+
+		ct.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})
+
+		// Server write loop.
+		var buf bytes.Buffer
+		enc := hpack.NewEncoder(&buf)
+		writeResp := make(chan uint32, maxConcurrent+1)
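+		// Only maxConcurrent+1 requests ever reach the server: one of the
+		// maxConcurrent+2 is canceled before it is sent. The buffer keeps
+		// the read loop from blocking while responses are held back.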
+
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			<-unblockServer
+			for id := range writeResp {
+				buf.Reset()
+				enc.WriteField(hpack.HeaderField{Name: ":status", Value: "204"})
+				ct.fr.WriteHeaders(HeadersFrameParam{
+					StreamID:      id,
+					EndHeaders:    true,
+					EndStream:     true,
+					BlockFragment: buf.Bytes(),
+				})
+			}
+		}()
+
+		// Server read loop.
+		var nreq int
+		for {
+			f, err := ct.fr.ReadFrame()
+			if err != nil {
+				select {
+				case <-clientDone:
+					// If the client's done, it will have reported any errors on its side.
+					return nil
+				default:
+					return err
+				}
+			}
+			switch f := f.(type) {
+			case *WindowUpdateFrame:
+			case *SettingsFrame:
+				// Receiving the client's SETTINGS ack means the greet is complete.
+				close(greet)
+			case *HeadersFrame:
+				if !f.HeadersEnded() {
+					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
+				}
+				gotRequest <- struct{}{}
+				nreq++
+				writeResp <- f.StreamID
+				if nreq == maxConcurrent+1 {
+					close(writeResp)
+				}
+			default:
+				return fmt.Errorf("Unexpected client frame %v", f)
+			}
+		}
+	}
+
+	ct.run()
+}
+
 func TestAuthorityAddr(t *testing.T) {
 	tests := []struct {
 		scheme, authority string