http2: add Transport.MaxReadFrameSize configuration setting
For golang/go#47840.
Fixes golang/go#54850.
Change-Id: I44efec8d1f223b3c2be82a2e11752fbbe8bf2cbf
Reviewed-on: https://go-review.googlesource.com/c/net/+/362834
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Auto-Submit: Damien Neil <dneil@google.com>
Reviewed-by: Joedian Reid <joedian@golang.org>
diff --git a/http2/transport.go b/http2/transport.go
index 91f4370..30f706e 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -118,6 +118,15 @@
// to mean no limit.
MaxHeaderListSize uint32
+ // MaxReadFrameSize is the http2 SETTINGS_MAX_FRAME_SIZE to send in the
+ // initial settings frame. It is the size in bytes of the largest frame
+ // payload that the sender is willing to receive. If 0, no setting is
+ // sent, and the value is provided by the peer, which should be 16384
+ // according to the spec:
+ // https://datatracker.ietf.org/doc/html/rfc7540#section-6.5.2.
+ // Values are bounded in the range 16k to 16M.
+ MaxReadFrameSize uint32
+
// MaxDecoderHeaderTableSize optionally specifies the http2
// SETTINGS_HEADER_TABLE_SIZE to send in the initial settings frame. It
// informs the remote endpoint of the maximum size of the header compression
@@ -184,6 +193,19 @@
return t.MaxHeaderListSize
}
+func (t *Transport) maxFrameReadSize() uint32 {
+ if t.MaxReadFrameSize == 0 {
+ return 0 // use the default provided by the peer
+ }
+ if t.MaxReadFrameSize < minMaxFrameSize {
+ return minMaxFrameSize
+ }
+ if t.MaxReadFrameSize > maxFrameSize {
+ return maxFrameSize
+ }
+ return t.MaxReadFrameSize
+}
+
func (t *Transport) disableCompression() bool {
return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
}
@@ -749,6 +771,9 @@
})
cc.br = bufio.NewReader(c)
cc.fr = NewFramer(cc.bw, cc.br)
+ if t.maxFrameReadSize() != 0 {
+ cc.fr.SetMaxReadFrameSize(t.maxFrameReadSize())
+ }
if t.CountError != nil {
cc.fr.countError = t.CountError
}
@@ -773,6 +798,9 @@
{ID: SettingEnablePush, Val: 0},
{ID: SettingInitialWindowSize, Val: transportDefaultStreamFlow},
}
+ if max := t.maxFrameReadSize(); max != 0 {
+ initialSettings = append(initialSettings, Setting{ID: SettingMaxFrameSize, Val: max})
+ }
if max := t.maxHeaderListSize(); max != 0 {
initialSettings = append(initialSettings, Setting{ID: SettingMaxHeaderListSize, Val: max})
}
diff --git a/http2/transport_test.go b/http2/transport_test.go
index ee852b6..42d7dbc 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -3998,6 +3998,64 @@
ct.run()
}
+// Test that the server received values are in range.
+func TestTransportMaxFrameReadSize(t *testing.T) {
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ }, func(s *Server) {
+ s.MaxConcurrentStreams = 1
+ })
+ defer st.Close()
+ tr := &Transport{
+ TLSClientConfig: tlsConfigInsecure,
+ MaxReadFrameSize: 64000,
+ }
+ defer tr.CloseIdleConnections()
+
+ req, err := http.NewRequest("GET", st.ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := res.StatusCode, 200; got != want {
+ t.Errorf("StatusCode = %v; want %v", got, want)
+ }
+ if res != nil && res.Body != nil {
+ res.Body.Close()
+ }
+ // Test that maxFrameReadSize() matches the requested size.
+ if got, want := tr.maxFrameReadSize(), uint32(64000); got != want {
+ t.Errorf("maxFrameReadSize = %v; want %v", got, want)
+ }
+ // Test that server receives the requested size.
+ if got, want := st.sc.maxFrameSize, int32(64000); got != want {
+ t.Errorf("maxFrameReadSize = %v; want %v", got, want)
+ }
+
+ // test for minimum frame read size
+ tr = &Transport{
+ TLSClientConfig: tlsConfigInsecure,
+ MaxReadFrameSize: 1024,
+ }
+
+ if got, want := tr.maxFrameReadSize(), uint32(minMaxFrameSize); got != want {
+ t.Errorf("maxFrameReadSize = %v; want %v", got, want)
+ }
+
+ // test for maximum frame size
+ tr = &Transport{
+ TLSClientConfig: tlsConfigInsecure,
+ MaxReadFrameSize: 1024 * 1024 * 16,
+ }
+
+ if got, want := tr.maxFrameReadSize(), uint32(maxFrameSize); got != want {
+ t.Errorf("maxFrameReadSize = %v; want %v", got, want)
+ }
+
+}
+
func TestTransportRequestsLowServerLimit(t *testing.T) {
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
}, optOnlyServer, func(s *Server) {
@@ -4608,6 +4666,61 @@
b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 1000) })
}
+func BenchmarkDownloadFrameSize(b *testing.B) {
+ b.Run(" 16k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 16*1024) })
+ b.Run(" 64k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 64*1024) })
+ b.Run("128k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 128*1024) })
+ b.Run("256k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 256*1024) })
+ b.Run("512k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 512*1024) })
+}
+func benchLargeDownloadRoundTrip(b *testing.B, frameSize uint32) {
+ defer disableGoroutineTracking()()
+ const transferSize = 1024 * 1024 * 1024 // must be multiple of 1M
+ b.ReportAllocs()
+ st := newServerTester(b,
+ func(w http.ResponseWriter, r *http.Request) {
+ // test 1GB transfer
+ w.Header().Set("Content-Length", strconv.Itoa(transferSize))
+ w.Header().Set("Content-Transfer-Encoding", "binary")
+ var data [1024 * 1024]byte
+ for i := 0; i < transferSize/(1024*1024); i++ {
+ w.Write(data[:])
+ }
+ }, optQuiet,
+ )
+ defer st.Close()
+
+ tr := &Transport{TLSClientConfig: tlsConfigInsecure, MaxReadFrameSize: frameSize}
+ defer tr.CloseIdleConnections()
+
+ req, err := http.NewRequest("GET", st.ts.URL, nil)
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ b.N = 3
+ b.SetBytes(transferSize)
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ if res != nil {
+ res.Body.Close()
+ }
+ b.Fatalf("RoundTrip err = %v; want nil", err)
+ }
+ data, _ := io.ReadAll(res.Body)
+ if len(data) != transferSize {
+ b.Fatalf("Response length invalid")
+ }
+ res.Body.Close()
+ if res.StatusCode != http.StatusOK {
+ b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
+ }
+ }
+}
+
func activeStreams(cc *ClientConn) int {
count := 0
cc.mu.Lock()