blob: 038cbc364961b2f552b12980d471f9e077bf321b [file] [log] [blame] [edit]
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package h2c
import (
"context"
"crypto/tls"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"golang.org/x/net/http2"
)
// ExampleNewHandler shows how to serve h2c (HTTP/2 without TLS) from an
// ordinary http.Server by wrapping a regular handler with NewHandler.
func ExampleNewHandler() {
	// Any handler can be wrapped; this one writes a fixed greeting.
	hello := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		fmt.Fprint(w, "Hello world")
	})

	// HTTP/2 server configuration handed to NewHandler.
	h2Server := &http2.Server{
		// ...
	}

	server := &http.Server{Addr: ":8080"}
	server.Handler = NewHandler(hello, h2Server)
	log.Fatal(server.ListenAndServe())
}
// TestContext verifies that a request served through the h2c handler is
// handled as HTTP/2 (ProtoMajor 2) and that its context derives from the
// http.Server's BaseContext.
func TestContext(t *testing.T) {
	// Use an unexported key type rather than a bare string so the context
	// value cannot collide with keys from other packages (the context package
	// forbids built-in key types; staticcheck SA1029).
	type contextKey string
	const testKey contextKey = "testkey"

	baseCtx := context.WithValue(context.Background(), testKey, "testvalue")
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.ProtoMajor != 2 {
			t.Errorf("Request wasn't handled by h2c. Got ProtoMajor=%v", r.ProtoMajor)
		}
		if r.Context().Value(testKey) != "testvalue" {
			t.Errorf("Request doesn't have expected base context: %v", r.Context())
		}
		fmt.Fprint(w, "Hello world")
	})
	h2s := &http2.Server{}
	h1s := httptest.NewUnstartedServer(NewHandler(handler, h2s))
	h1s.Config.BaseContext = func(_ net.Listener) context.Context {
		return baseCtx
	}
	h1s.Start()
	defer h1s.Close()
	client := &http.Client{
		Transport: &http2.Transport{
			// Speak HTTP/2 directly over cleartext TCP: the "TLS" dialer
			// returns a plain, unencrypted connection.
			AllowHTTP: true,
			DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
				return net.Dial(network, addr)
			},
		},
	}
	// Deferred after h1s.Close, so it runs first: release the client's pooled
	// connections before the server shuts down.
	defer client.CloseIdleConnections()
	resp, err := client.Get(h1s.URL)
	if err != nil {
		t.Fatal(err)
	}
	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}
	if err := resp.Body.Close(); err != nil {
		t.Fatal(err)
	}
	if got, want := string(body), "Hello world"; got != want {
		t.Errorf("response body = %q, want %q", got, want)
	}
}
// TestPropagation verifies that the h2c handler carries the originating
// http.Server through to HTTP/2 requests. The handler expects to find that
// server under http.ServerContextKey. The server's MaxHeaderBytes limit is
// set below the size of the request's header, and the request is expected
// to fail.
func TestPropagation(t *testing.T) {
	var (
		server *http.Server
		// headerSize is deliberately double headerLimit because HTTP/2
		// compresses headers; the margin keeps the request over the limit.
		headerSize  = 1 << 11
		headerLimit = 1 << 10
	)
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.ProtoMajor != 2 {
			t.Errorf("Request wasn't handled by h2c. Got ProtoMajor=%v", r.ProtoMajor)
		}
		// Comma-ok assertion: a missing or mistyped value is reported as a
		// test failure rather than panicking the handler.
		if got, _ := r.Context().Value(http.ServerContextKey).(*http.Server); got != server {
			t.Errorf("Request doesn't have expected http server: %v", r.Context())
		}
		if len(r.Header.Get("Long-Header")) != headerSize {
			t.Errorf("Request doesn't have expected http header length: %v", len(r.Header.Get("Long-Header")))
		}
		fmt.Fprint(w, "Hello world")
	})
	h2s := &http2.Server{}
	h1s := httptest.NewUnstartedServer(NewHandler(handler, h2s))
	server = h1s.Config
	server.MaxHeaderBytes = headerLimit
	server.ConnState = func(conn net.Conn, state http.ConnState) {
		t.Logf("server conn state: conn %s -> %s, status changed to %s", conn.RemoteAddr(), conn.LocalAddr(), state)
	}
	h1s.Start()
	defer h1s.Close()
	client := &http.Client{
		Transport: &http2.Transport{
			// Speak HTTP/2 directly over cleartext TCP: the "TLS" dialer
			// returns a plain, unencrypted connection.
			AllowHTTP: true,
			DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
				conn, err := net.Dial(network, addr)
				if conn != nil {
					t.Logf("client dial tls: %s -> %s", conn.RemoteAddr(), conn.LocalAddr())
				}
				return conn, err
			},
		},
	}
	// Deferred after h1s.Close, so it runs first: release the client's pooled
	// connections before the server shuts down.
	defer client.CloseIdleConnections()
	req, err := http.NewRequest("GET", h1s.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Set("Long-Header", strings.Repeat("A", headerSize))
	resp, err := client.Do(req)
	if err == nil {
		// Unexpected success: release the response before failing.
		resp.Body.Close()
		t.Fatal("expected server err, got nil")
	}
}
func TestMaxBytesHandler(t *testing.T) {
const bodyLimit = 10
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Errorf("got request, expected to be blocked by body limit")
})
h2s := &http2.Server{}
h1s := httptest.NewUnstartedServer(http.MaxBytesHandler(NewHandler(handler, h2s), bodyLimit))
h1s.Start()
defer h1s.Close()
// Wrap the body in a struct{io.Reader} to prevent it being rewound and resent.
body := "0123456789abcdef"
req, err := http.NewRequest("POST", h1s.URL, struct{ io.Reader }{strings.NewReader(body)})
if err != nil {
t.Fatal(err)
}
req.Header.Set("Http2-Settings", "")
req.Header.Set("Upgrade", "h2c")
req.Header.Set("Connection", "Upgrade, HTTP2-Settings")
resp, err := h1s.Client().Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
_, err = ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if got, want := resp.StatusCode, http.StatusInternalServerError; got != want {
t.Errorf("resp.StatusCode = %v, want %v", got, want)
}
}