[internal-branch.go1.21-vendor] http2: limit maximum handler goroutines to MaxConcurrentStreams

When the peer opens a new stream while we have MaxConcurrentStreams
handler goroutines running, defer starting a handler until one
of the existing handlers exits.

For golang/go#63417.
For golang/go#63427.
For CVE-2023-39325.

Change-Id: If0531e177b125700f3e24c5ebd24b1023098fa6d
Reviewed-on: https://team-review.git.corp.google.com/c/golang/go-private/+/2047391
Reviewed-by: Tatiana Bradley <tatianabradley@google.com>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Ian Cottrell <iancottrell@google.com>
Reviewed-on: https://go-review.googlesource.com/c/net/+/534218
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Michael Pratt <mpratt@google.com>
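
Note: the cap enforced by this change is the server's advertised SETTINGS_MAX_CONCURRENT_STREAMS,
taken from Server.MaxConcurrentStreams (the package default, currently 250, applies when the field
is zero). The following is a minimal sketch of how a server sets that value through
http2.ConfigureServer; the handler body, address, and certificate paths are placeholders and not
part of this change.

    package main

    import (
        "fmt"
        "log"
        "net/http"

        "golang.org/x/net/http2"
    )

    func main() {
        srv := &http.Server{
            Addr: ":8443", // placeholder address
            Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
                fmt.Fprintln(w, "hello")
            }),
        }

        // MaxConcurrentStreams is advertised to the peer as
        // SETTINGS_MAX_CONCURRENT_STREAMS; with this change it also caps the
        // number of handler goroutines running per connection.
        if err := http2.ConfigureServer(srv, &http2.Server{
            MaxConcurrentStreams: 100,
        }); err != nil {
            log.Fatal(err)
        }

        // cert.pem and key.pem are placeholders for real TLS credentials.
        log.Fatal(srv.ListenAndServeTLS("cert.pem", "key.pem"))
    }
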
diff --git a/http2/server.go b/http2/server.go
index 033b6e6..4561e3c 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -581,9 +581,11 @@
 	advMaxStreams               uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client
 	curClientStreams            uint32 // number of open streams initiated by the client
 	curPushedStreams            uint32 // number of open streams initiated by server push
+	curHandlers                 uint32 // number of running handler goroutines
 	maxClientStreamID           uint32 // max ever seen from client (odd), or 0 if there have been no client requests
 	maxPushPromiseID            uint32 // ID of the last push promise (even), or 0 if there have been no pushes
 	streams                     map[uint32]*stream
+	unstartedHandlers           []unstartedHandler
 	initialStreamSendWindowSize int32
 	maxFrameSize                int32
 	peerMaxHeaderListSize       uint32            // zero means unknown (default)
@@ -981,6 +983,8 @@
 					return
 				case gracefulShutdownMsg:
 					sc.startGracefulShutdownInternal()
+				case handlerDoneMsg:
+					sc.handlerDone()
 				default:
 					panic("unknown timer")
 				}
@@ -1028,6 +1032,7 @@
 	idleTimerMsg        = new(serverMessage)
 	shutdownTimerMsg    = new(serverMessage)
 	gracefulShutdownMsg = new(serverMessage)
+	handlerDoneMsg      = new(serverMessage)
 )
 
 func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) }
@@ -2025,8 +2030,7 @@
 		}
 	}
 
-	go sc.runHandler(rw, req, handler)
-	return nil
+	return sc.scheduleHandler(id, rw, req, handler)
 }
 
 func (sc *serverConn) upgradeRequest(req *http.Request) {
@@ -2046,6 +2050,10 @@
 		sc.conn.SetReadDeadline(time.Time{})
 	}
 
+	// This is the first request on the connection,
+	// so start the handler directly rather than going
+	// through scheduleHandler.
+	sc.curHandlers++
 	go sc.runHandler(rw, req, sc.handler.ServeHTTP)
 }
 
@@ -2286,8 +2294,62 @@
 	return &responseWriter{rws: rws}
 }
 
+type unstartedHandler struct {
+	streamID uint32
+	rw       *responseWriter
+	req      *http.Request
+	handler  func(http.ResponseWriter, *http.Request)
+}
+
+// scheduleHandler starts a handler goroutine,
+// or schedules one to start as soon as an existing handler finishes.
+func (sc *serverConn) scheduleHandler(streamID uint32, rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) error {
+	sc.serveG.check()
+	maxHandlers := sc.advMaxStreams
+	if sc.curHandlers < maxHandlers {
+		sc.curHandlers++
+		go sc.runHandler(rw, req, handler)
+		return nil
+	}
+	if len(sc.unstartedHandlers) > int(4*sc.advMaxStreams) {
+		return sc.countError("too_many_early_resets", ConnectionError(ErrCodeEnhanceYourCalm))
+	}
+	sc.unstartedHandlers = append(sc.unstartedHandlers, unstartedHandler{
+		streamID: streamID,
+		rw:       rw,
+		req:      req,
+		handler:  handler,
+	})
+	return nil
+}
+
+func (sc *serverConn) handlerDone() {
+	sc.serveG.check()
+	sc.curHandlers--
+	i := 0
+	maxHandlers := sc.advMaxStreams
+	for ; i < len(sc.unstartedHandlers); i++ {
+		u := sc.unstartedHandlers[i]
+		if sc.streams[u.streamID] == nil {
+			// This stream was reset before its goroutine had a chance to start.
+			continue
+		}
+		if sc.curHandlers >= maxHandlers {
+			break
+		}
+		sc.curHandlers++
+		go sc.runHandler(u.rw, u.req, u.handler)
+		sc.unstartedHandlers[i] = unstartedHandler{} // don't retain references
+	}
+	sc.unstartedHandlers = sc.unstartedHandlers[i:]
+	if len(sc.unstartedHandlers) == 0 {
+		sc.unstartedHandlers = nil
+	}
+}
+
 // Run on its own goroutine.
 func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) {
+	defer sc.sendServeMsg(handlerDoneMsg)
 	didPanic := true
 	defer func() {
 		rw.rws.stream.cancelCtx()
diff --git a/http2/server_test.go b/http2/server_test.go
index cd73291..28f8c44 100644
--- a/http2/server_test.go
+++ b/http2/server_test.go
@@ -4756,3 +4756,116 @@
 	st.ts.Config.Close()
 	<-donec
 }
+
+func TestServerMaxHandlerGoroutines(t *testing.T) {
+	const maxHandlers = 10
+	handlerc := make(chan chan bool)
+	donec := make(chan struct{})
+	defer close(donec)
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		stopc := make(chan bool, 1)
+		select {
+		case handlerc <- stopc:
+		case <-donec:
+		}
+		select {
+		case shouldPanic := <-stopc:
+			if shouldPanic {
+				panic(http.ErrAbortHandler)
+			}
+		case <-donec:
+		}
+	}, func(s *Server) {
+		s.MaxConcurrentStreams = maxHandlers
+	})
+	defer st.Close()
+
+	st.writePreface()
+	st.writeInitialSettings()
+	st.writeSettingsAck()
+
+	// Make maxHandlers concurrent requests.
+	// Reset them all, but only after the handler goroutines have started.
+	var stops []chan bool
+	streamID := uint32(1)
+	for i := 0; i < maxHandlers; i++ {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      streamID,
+			BlockFragment: st.encodeHeader(),
+			EndStream:     true,
+			EndHeaders:    true,
+		})
+		stops = append(stops, <-handlerc)
+		st.fr.WriteRSTStream(streamID, ErrCodeCancel)
+		streamID += 2
+	}
+
+	// Start another request, and immediately reset it.
+	st.writeHeaders(HeadersFrameParam{
+		StreamID:      streamID,
+		BlockFragment: st.encodeHeader(),
+		EndStream:     true,
+		EndHeaders:    true,
+	})
+	st.fr.WriteRSTStream(streamID, ErrCodeCancel)
+	streamID += 2
+
+	// Start another two requests. Don't reset these.
+	for i := 0; i < 2; i++ {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      streamID,
+			BlockFragment: st.encodeHeader(),
+			EndStream:     true,
+			EndHeaders:    true,
+		})
+		streamID += 2
+	}
+
+	// The initial maxHandlers handlers are still executing,
+	// so the last two requests don't start any new handlers.
+	select {
+	case <-handlerc:
+		t.Errorf("handler unexpectedly started while maxHandlers are already running")
+	case <-time.After(1 * time.Millisecond):
+	}
+
+	// Tell two handlers to exit.
+	// The pending requests which weren't reset start handlers.
+	stops[0] <- false // normal exit
+	stops[1] <- true  // panic
+	stops = stops[2:]
+	stops = append(stops, <-handlerc)
+	stops = append(stops, <-handlerc)
+
+	// Make a bunch more requests.
+	// Eventually, the server tells us to go away.
+	for i := 0; i < 5*maxHandlers; i++ {
+		st.writeHeaders(HeadersFrameParam{
+			StreamID:      streamID,
+			BlockFragment: st.encodeHeader(),
+			EndStream:     true,
+			EndHeaders:    true,
+		})
+		st.fr.WriteRSTStream(streamID, ErrCodeCancel)
+		streamID += 2
+	}
+Frames:
+	for {
+		f, err := st.readFrame()
+		if err != nil {
+			st.t.Fatal(err)
+		}
+		switch f := f.(type) {
+		case *GoAwayFrame:
+			if f.ErrCode != ErrCodeEnhanceYourCalm {
+				t.Errorf("err code = %v; want %v", f.ErrCode, ErrCodeEnhanceYourCalm)
+			}
+			break Frames
+		default:
+		}
+	}
+
+	for _, s := range stops {
+		close(s)
+	}
+}
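
Note: the patch introduces a new countError name, "too_many_early_resets", reported when the
unstarted-handler queue grows past 4*SETTINGS_MAX_CONCURRENT_STREAMS and the connection is closed
with ENHANCE_YOUR_CALM. A hedged sketch of observing it through the existing Server.CountError hook
follows; the package name, metric name, and substring match are illustrative assumptions, since
CountError receives a composite ASCII token derived from the error class, code, and name rather
than the bare name.

    // Package h2obs is an illustrative helper, not part of this change.
    package h2obs

    import (
        "expvar"
        "strings"

        "golang.org/x/net/http2"
    )

    // earlyResetConns counts connections the server closed because the peer
    // reset too many requests before their handlers could start.
    var earlyResetConns = expvar.NewInt("http2_too_many_early_resets")

    // NewServer returns an http2.Server that records the new error token;
    // attach it with http2.ConfigureServer as in the earlier sketch.
    func NewServer() *http2.Server {
        return &http2.Server{
            MaxConcurrentStreams: 100,
            // CountError receives a short ASCII token derived from the error
            // class, code, and name; match on the substring rather than
            // depending on the exact format.
            CountError: func(errType string) {
                if strings.Contains(errType, "too_many_early_resets") {
                    earlyResetConns.Add(1)
                }
            },
        }
    }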