http2/h2c: propagate HTTP/1 server configuration to HTTP/2
Fixes golang/go#37089
Change-Id: I793bf8b420fd7b5a47b45ad1521c5b5f9e0321b2
GitHub-Last-Rev: 805b90e36a9a9986a57de86eb8f6725359f7abfe
GitHub-Pull-Request: golang/net#139
Reviewed-on: https://go-review.googlesource.com/c/net/+/419181
Reviewed-by: Michael Knyszek <mknyszek@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
diff --git a/http2/h2c/h2c.go b/http2/h2c/h2c.go
index c3df711..2b77ffd 100644
--- a/http2/h2c/h2c.go
+++ b/http2/h2c/h2c.go
@@ -70,6 +70,15 @@
}
}
+// extractServer extracts existing http.Server instance from http.Request or create an empty http.Server
+func extractServer(r *http.Request) *http.Server {
+ server, ok := r.Context().Value(http.ServerContextKey).(*http.Server)
+ if ok {
+ return server
+ }
+ return new(http.Server)
+}
+
// ServeHTTP implement the h2c support that is enabled by h2c.GetH2CHandler.
func (s h2cHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Handle h2c with prior knowledge (RFC 7540 Section 3.4)
@@ -87,6 +96,7 @@
defer conn.Close()
s.s.ServeConn(conn, &http2.ServeConnOpts{
Context: r.Context(),
+ BaseConfig: extractServer(r),
Handler: s.Handler,
SawClientPreface: true,
})
@@ -104,6 +114,7 @@
defer conn.Close()
s.s.ServeConn(conn, &http2.ServeConnOpts{
Context: r.Context(),
+ BaseConfig: extractServer(r),
Handler: s.Handler,
UpgradeRequest: r,
Settings: settings,
diff --git a/http2/h2c/h2c_test.go b/http2/h2c/h2c_test.go
index 3e5a2eb..558e597 100644
--- a/http2/h2c/h2c_test.go
+++ b/http2/h2c/h2c_test.go
@@ -13,6 +13,7 @@
"net"
"net/http"
"net/http/httptest"
+ "strings"
"testing"
"golang.org/x/net/http2"
@@ -74,3 +75,62 @@
t.Fatal(err)
}
}
+
+func TestPropagation(t *testing.T) {
+ var (
+ server *http.Server
+ // double the limit because http2 will compress header
+ headerSize = 1 << 11
+ headerLimit = 1 << 10
+ )
+
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.ProtoMajor != 2 {
+ t.Errorf("Request wasn't handled by h2c. Got ProtoMajor=%v", r.ProtoMajor)
+ }
+ if r.Context().Value(http.ServerContextKey).(*http.Server) != server {
+ t.Errorf("Request doesn't have expected http server: %v", r.Context())
+ }
+ if len(r.Header.Get("Long-Header")) != headerSize {
+ t.Errorf("Request doesn't have expected http header length: %v", len(r.Header.Get("Long-Header")))
+ }
+ fmt.Fprint(w, "Hello world")
+ })
+
+ h2s := &http2.Server{}
+ h1s := httptest.NewUnstartedServer(NewHandler(handler, h2s))
+
+ server = h1s.Config
+ server.MaxHeaderBytes = headerLimit
+ server.ConnState = func(conn net.Conn, state http.ConnState) {
+ t.Logf("server conn state: conn %s -> %s, status changed to %s", conn.RemoteAddr(), conn.LocalAddr(), state)
+ }
+
+ h1s.Start()
+ defer h1s.Close()
+
+ client := &http.Client{
+ Transport: &http2.Transport{
+ AllowHTTP: true,
+ DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
+ conn, err := net.Dial(network, addr)
+ if conn != nil {
+ t.Logf("client dial tls: %s -> %s", conn.RemoteAddr(), conn.LocalAddr())
+ }
+ return conn, err
+ },
+ },
+ }
+
+ req, err := http.NewRequest("GET", h1s.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ req.Header.Set("Long-Header", strings.Repeat("A", headerSize))
+
+ _, err = client.Do(req)
+ if err == nil {
+ t.Fatal("expected server err, got nil")
+ }
+}