http2: limit 1xx based on size, do not limit when delivered

Replace Transport's limit of 5 1xx responses with a limit based
on the maximum header size: The total size of all 1xx response
headers must not exceed the limit we use on the size of the
final response headers.

(This differs slightly from the corresponding HTTP/1 change,
which imposes a limit on all 1xx response headers *plus* the
final response headers. The difference isn't substantial,
and this implementation fits better with the HTTP/2 framer.)

When the user is reading 1xx responses using a Got1xxResponse
client trace hook, disable the limit: Each 1xx response is
individually limited by the header size limit, but there
is no limit on the total number of responses. The user is
responsible for imposing a limit if they want one.

For golang/go#65035

Change-Id: I9c19dbf068e0f580789d952f63113b3d21ad86fc
Reviewed-on: https://go-review.googlesource.com/c/net/+/615295
Reviewed-by: Cherry Mui <cherryyz@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Damien Neil <dneil@google.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 0c5f64a..e989bd1 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -420,12 +420,12 @@
 	sentHeaders   bool
 
 	// owned by clientConnReadLoop:
-	firstByte    bool  // got the first response byte
-	pastHeaders  bool  // got first MetaHeadersFrame (actual headers)
-	pastTrailers bool  // got optional second MetaHeadersFrame (trailers)
-	num1xx       uint8 // number of 1xx responses seen
-	readClosed   bool  // peer sent an END_STREAM flag
-	readAborted  bool  // read loop reset the stream
+	firstByte       bool  // got the first response byte
+	pastHeaders     bool  // got first MetaHeadersFrame (actual headers)
+	pastTrailers    bool  // got optional second MetaHeadersFrame (trailers)
+	readClosed      bool  // peer sent an END_STREAM flag
+	readAborted     bool  // read loop reset the stream
+	totalHeaderSize int64 // total size of 1xx headers seen
 
 	trailer    http.Header  // accumulated trailers
 	resTrailer *http.Header // client's Response.Trailer
@@ -2494,15 +2494,34 @@
 		if f.StreamEnded() {
 			return nil, errors.New("1xx informational response with END_STREAM flag")
 		}
-		cs.num1xx++
-		const max1xxResponses = 5 // arbitrary bound on number of informational responses, same as net/http
-		if cs.num1xx > max1xxResponses {
-			return nil, errors.New("http2: too many 1xx informational responses")
-		}
 		if fn := cs.get1xxTraceFunc(); fn != nil {
+			// If the 1xx response is being delivered to the user,
+			// then they're responsible for limiting the number
+			// of responses.
 			if err := fn(statusCode, textproto.MIMEHeader(header)); err != nil {
 				return nil, err
 			}
+		} else {
+			// If the user didn't examine the 1xx response, then we
+			// limit the size of all 1xx headers.
+			//
+			// This differs a bit from the HTTP/1 implementation, which
+			// limits the size of all 1xx headers plus the final response.
+			// Use the larger limit of MaxHeaderListSize and
+			// net/http.Transport.MaxResponseHeaderBytes.
+			limit := int64(cs.cc.t.maxHeaderListSize())
+			if t1 := cs.cc.t.t1; t1 != nil && t1.MaxResponseHeaderBytes > limit {
+				limit = t1.MaxResponseHeaderBytes
+			}
+			for _, h := range f.Fields {
+				cs.totalHeaderSize += int64(h.Size())
+			}
+			if cs.totalHeaderSize > limit {
+				if VerboseLogs {
+					log.Printf("http2: 1xx informational responses too large")
+				}
+				return nil, errors.New("header list too large")
+			}
 		}
 		if statusCode == 100 {
 			traceGot100Continue(cs.trace)
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 498e279..757a45a 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -5421,3 +5421,94 @@
 		res.Body.Close()
 	}
 }
+
+func TestTransport1xxLimits(t *testing.T) {
+	for _, test := range []struct {
+		name    string
+		opt     any
+		ctxfn   func(context.Context) context.Context
+		hcount  int
+		limited bool
+	}{{
+		name:    "default",
+		hcount:  10,
+		limited: false,
+	}, {
+		name: "MaxHeaderListSize",
+		opt: func(tr *Transport) {
+			tr.MaxHeaderListSize = 10000
+		},
+		hcount:  10,
+		limited: true,
+	}, {
+		name: "MaxResponseHeaderBytes",
+		opt: func(tr *http.Transport) {
+			tr.MaxResponseHeaderBytes = 10000
+		},
+		hcount:  10,
+		limited: true,
+	}, {
+		name: "limit by client trace",
+		ctxfn: func(ctx context.Context) context.Context {
+			count := 0
+			return httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
+				Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+					count++
+					if count >= 10 {
+						return errors.New("too many 1xx")
+					}
+					return nil
+				},
+			})
+		},
+		hcount:  10,
+		limited: true,
+	}, {
+		name: "limit disabled by client trace",
+		opt: func(tr *Transport) {
+			tr.MaxHeaderListSize = 10000
+		},
+		ctxfn: func(ctx context.Context) context.Context {
+			return httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
+				Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+					return nil
+				},
+			})
+		},
+		hcount:  20,
+		limited: false,
+	}} {
+		t.Run(test.name, func(t *testing.T) {
+			tc := newTestClientConn(t, test.opt)
+			tc.greet()
+
+			ctx := context.Background()
+			if test.ctxfn != nil {
+				ctx = test.ctxfn(ctx)
+			}
+			req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
+			rt := tc.roundTrip(req)
+			tc.wantFrameType(FrameHeaders)
+
+			for i := 0; i < test.hcount; i++ {
+				if fr, err := tc.fr.ReadFrame(); err != os.ErrDeadlineExceeded {
+					t.Fatalf("after writing %v 1xx headers: read %v, %v; want idle", i, fr, err)
+				}
+				tc.writeHeaders(HeadersFrameParam{
+					StreamID:   rt.streamID(),
+					EndHeaders: true,
+					EndStream:  false,
+					BlockFragment: tc.makeHeaderBlockFragment(
+						":status", "103",
+						"x-field", strings.Repeat("a", 1000),
+					),
+				})
+			}
+			if test.limited {
+				tc.wantFrameType(FrameRSTStream)
+			} else {
+				tc.wantIdle()
+			}
+		})
+	}
+}