http2: add Transport.MaxReadFrameSize configuration setting

For golang/go#47840.
Fixes golang/go#54850.

Change-Id: I44efec8d1f223b3c2be82a2e11752fbbe8bf2cbf
Reviewed-on: https://go-review.googlesource.com/c/net/+/362834
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Auto-Submit: Damien Neil <dneil@google.com>
Reviewed-by: Joedian Reid <joedian@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 91f4370..30f706e 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -118,6 +118,15 @@
 	// to mean no limit.
 	MaxHeaderListSize uint32
 
+	// MaxReadFrameSize is the http2 SETTINGS_MAX_FRAME_SIZE to send in the
+	// initial settings frame. It is the size in bytes of the largest frame
+	// payload that the sender is willing to receive. If 0, no setting is
+	// sent, and the value is provided by the peer, which should be 16384
+	// according to the spec:
+	// https://datatracker.ietf.org/doc/html/rfc7540#section-6.5.2.
+	// Values are bounded in the range 16k to 16M.
+	MaxReadFrameSize uint32
+
 	// MaxDecoderHeaderTableSize optionally specifies the http2
 	// SETTINGS_HEADER_TABLE_SIZE to send in the initial settings frame. It
 	// informs the remote endpoint of the maximum size of the header compression
@@ -184,6 +193,19 @@
 	return t.MaxHeaderListSize
 }
 
+func (t *Transport) maxFrameReadSize() uint32 {
+	if t.MaxReadFrameSize == 0 {
+		return 0 // use the default provided by the peer
+	}
+	if t.MaxReadFrameSize < minMaxFrameSize {
+		return minMaxFrameSize
+	}
+	if t.MaxReadFrameSize > maxFrameSize {
+		return maxFrameSize
+	}
+	return t.MaxReadFrameSize
+}
+
 func (t *Transport) disableCompression() bool {
 	return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
 }
@@ -749,6 +771,9 @@
 	})
 	cc.br = bufio.NewReader(c)
 	cc.fr = NewFramer(cc.bw, cc.br)
+	if t.maxFrameReadSize() != 0 {
+		cc.fr.SetMaxReadFrameSize(t.maxFrameReadSize())
+	}
 	if t.CountError != nil {
 		cc.fr.countError = t.CountError
 	}
@@ -773,6 +798,9 @@
 		{ID: SettingEnablePush, Val: 0},
 		{ID: SettingInitialWindowSize, Val: transportDefaultStreamFlow},
 	}
+	if max := t.maxFrameReadSize(); max != 0 {
+		initialSettings = append(initialSettings, Setting{ID: SettingMaxFrameSize, Val: max})
+	}
 	if max := t.maxHeaderListSize(); max != 0 {
 		initialSettings = append(initialSettings, Setting{ID: SettingMaxHeaderListSize, Val: max})
 	}
diff --git a/http2/transport_test.go b/http2/transport_test.go
index ee852b6..42d7dbc 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -3998,6 +3998,64 @@
 	ct.run()
 }
 
+// Test that the server received values are in range.
+func TestTransportMaxFrameReadSize(t *testing.T) {
+	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+	}, func(s *Server) {
+		s.MaxConcurrentStreams = 1
+	})
+	defer st.Close()
+	tr := &Transport{
+		TLSClientConfig:  tlsConfigInsecure,
+		MaxReadFrameSize: 64000,
+	}
+	defer tr.CloseIdleConnections()
+
+	req, err := http.NewRequest("GET", st.ts.URL, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	res, err := tr.RoundTrip(req)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if got, want := res.StatusCode, 200; got != want {
+		t.Errorf("StatusCode = %v; want %v", got, want)
+	}
+	if res != nil && res.Body != nil {
+		res.Body.Close()
+	}
+	// Test that maxFrameReadSize() matches the requested size.
+	if got, want := tr.maxFrameReadSize(), uint32(64000); got != want {
+		t.Errorf("maxFrameReadSize = %v; want %v", got, want)
+	}
+	// Test that server receives the requested size.
+	if got, want := st.sc.maxFrameSize, int32(64000); got != want {
+		t.Errorf("maxFrameReadSize = %v; want %v", got, want)
+	}
+
+	// test for minimum frame read size
+	tr = &Transport{
+		TLSClientConfig:  tlsConfigInsecure,
+		MaxReadFrameSize: 1024,
+	}
+
+	if got, want := tr.maxFrameReadSize(), uint32(minMaxFrameSize); got != want {
+		t.Errorf("maxFrameReadSize = %v; want %v", got, want)
+	}
+
+	// test for maximum frame size
+	tr = &Transport{
+		TLSClientConfig:  tlsConfigInsecure,
+		MaxReadFrameSize: 1024 * 1024 * 16,
+	}
+
+	if got, want := tr.maxFrameReadSize(), uint32(maxFrameSize); got != want {
+		t.Errorf("maxFrameReadSize = %v; want %v", got, want)
+	}
+
+}
+
 func TestTransportRequestsLowServerLimit(t *testing.T) {
 	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
 	}, optOnlyServer, func(s *Server) {
@@ -4608,6 +4666,61 @@
 	b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 1000) })
 }
 
+func BenchmarkDownloadFrameSize(b *testing.B) {
+	b.Run(" 16k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 16*1024) })
+	b.Run(" 64k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 64*1024) })
+	b.Run("128k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 128*1024) })
+	b.Run("256k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 256*1024) })
+	b.Run("512k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 512*1024) })
+}
+func benchLargeDownloadRoundTrip(b *testing.B, frameSize uint32) {
+	defer disableGoroutineTracking()()
+	const transferSize = 1024 * 1024 * 1024 // must be multiple of 1M
+	b.ReportAllocs()
+	st := newServerTester(b,
+		func(w http.ResponseWriter, r *http.Request) {
+			// test 1GB transfer
+			w.Header().Set("Content-Length", strconv.Itoa(transferSize))
+			w.Header().Set("Content-Transfer-Encoding", "binary")
+			var data [1024 * 1024]byte
+			for i := 0; i < transferSize/(1024*1024); i++ {
+				w.Write(data[:])
+			}
+		}, optQuiet,
+	)
+	defer st.Close()
+
+	tr := &Transport{TLSClientConfig: tlsConfigInsecure, MaxReadFrameSize: frameSize}
+	defer tr.CloseIdleConnections()
+
+	req, err := http.NewRequest("GET", st.ts.URL, nil)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	b.N = 3
+	b.SetBytes(transferSize)
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		res, err := tr.RoundTrip(req)
+		if err != nil {
+			if res != nil {
+				res.Body.Close()
+			}
+			b.Fatalf("RoundTrip err = %v; want nil", err)
+		}
+		data, _ := io.ReadAll(res.Body)
+		if len(data) != transferSize {
+			b.Fatalf("Response length invalid")
+		}
+		res.Body.Close()
+		if res.StatusCode != http.StatusOK {
+			b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
+		}
+	}
+}
+
 func activeStreams(cc *ClientConn) int {
 	count := 0
 	cc.mu.Lock()