blob: 0aad15c5ed2129b4fca23fd122bd0a2ecb8f71a2 [file] [log] [blame]
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package httptest
import (
"bufio"
"io/ioutil"
"net"
"net/http"
"testing"
)
// newServerFunc is the signature shared by the server constructors exercised
// in this file: it accepts the handler to serve and returns a *Server.
type newServerFunc func(http.Handler) *Server
// newServers maps a constructor name to the function that builds and starts
// a Server, so the tests can run once per construction path.
var newServers = map[string]newServerFunc{
	"NewServer":    NewServer,
	"NewTLSServer": NewTLSServer,

	// The "Manual" entries bypass the constructors: they populate only the
	// exported fields of Server and then start it themselves.
	"NewServerManual": func(h http.Handler) *Server {
		srv := new(Server)
		srv.Listener = newLocalListener()
		srv.Config = &http.Server{Handler: h}
		srv.Start()
		return srv
	},
	"NewTLSServerManual": func(h http.Handler) *Server {
		srv := new(Server)
		srv.Listener = newLocalListener()
		srv.Config = &http.Server{Handler: h}
		srv.StartTLS()
		return srv
	},
}
func TestServer(t *testing.T) {
for _, name := range []string{"NewServer", "NewServerManual"} {
t.Run(name, func(t *testing.T) {
newServer := newServers[name]
t.Run("Server", func(t *testing.T) { testServer(t, newServer) })
t.Run("GetAfterClose", func(t *testing.T) { testGetAfterClose(t, newServer) })
t.Run("ServerCloseBlocking", func(t *testing.T) { testServerCloseBlocking(t, newServer) })
t.Run("ServerCloseClientConnections", func(t *testing.T) { testServerCloseClientConnections(t, newServer) })
t.Run("ServerClientTransportType", func(t *testing.T) { testServerClientTransportType(t, newServer) })
})
}
for _, name := range []string{"NewTLSServer", "NewTLSServerManual"} {
t.Run(name, func(t *testing.T) {
newServer := newServers[name]
t.Run("ServerClient", func(t *testing.T) { testServerClient(t, newServer) })
t.Run("TLSServerClientTransportType", func(t *testing.T) { testTLSServerClientTransportType(t, newServer) })
})
}
}
// testServer starts a server whose handler writes "hello", fetches its URL
// with http.Get, and checks that the response body is exactly that string.
func testServer(t *testing.T, newServer newServerFunc) {
	hello := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte("hello"))
	})
	srv := newServer(hello)
	defer srv.Close()

	resp, err := http.Get(srv.URL)
	if err != nil {
		t.Fatal(err)
	}

	// Close the body before inspecting the read error so it is released on
	// both the success and the failure path.
	body, readErr := ioutil.ReadAll(resp.Body)
	resp.Body.Close()
	if readErr != nil {
		t.Fatal(readErr)
	}

	if got := string(body); got != "hello" {
		t.Errorf("got %q, want hello", got)
	}
}
// testGetAfterClose checks that a request made after the server has been
// closed fails rather than being answered. It first confirms the server is
// reachable, then closes it and expects the second http.Get to return an
// error.
//
// Issue 12781
func testGetAfterClose(t *testing.T, newServer newServerFunc) {
	ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hello"))
	}))

	res, err := http.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	got, err := ioutil.ReadAll(res.Body)
	// Close the body regardless of the read outcome so the response is not
	// leaked.
	res.Body.Close()
	if err != nil {
		t.Fatal(err)
	}
	if string(got) != "hello" {
		t.Fatalf("got %q, want hello", string(got))
	}

	ts.Close()

	res, err = http.Get(ts.URL)
	if err == nil {
		// The read error is deliberately ignored: the body is only used to
		// enrich the failure message.
		body, _ := ioutil.ReadAll(res.Body)
		res.Body.Close()
		t.Fatalf("Unexpected response after close: %v, %v, %s", res.Status, res.Header, body)
	}
}
// testServerCloseBlocking checks that Server.Close returns while client
// connections are still open: one that has connected but sent nothing, and
// one that has completed a request and is now idle.
func testServerCloseBlocking(t *testing.T, newServer newServerFunc) {
	ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hello"))
	}))

	// dial opens a raw TCP connection to the server's listener address and
	// aborts the test if that fails.
	dial := func() net.Conn {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		return c
	}

	// Keep one connection in StateNew (connected, but not sending anything)
	cnew := dial()
	defer cnew.Close()

	// Keep one connection in StateIdle (idle after a request)
	cidle := dial()
	defer cidle.Close()
	// Fail fast on a write error instead of letting it surface later as a
	// confusing ReadResponse failure.
	if _, err := cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n")); err != nil {
		t.Fatal(err)
	}
	if _, err := http.ReadResponse(bufio.NewReader(cidle), nil); err != nil {
		t.Fatal(err)
	}

	ts.Close() // test we don't hang here forever.
}
// testServerCloseClientConnections starts a server whose handler calls
// CloseClientConnections on that same server, and expects the client's
// http.Get to fail as a result.
//
// Issue 14290
func testServerCloseClientConnections(t *testing.T, newServer newServerFunc) {
	// srv is declared before it is assigned so the handler closure can
	// refer to the server it is running in.
	var srv *Server
	dropConns := func(http.ResponseWriter, *http.Request) {
		srv.CloseClientConnections()
	}
	srv = newServer(http.HandlerFunc(dropConns))
	defer srv.Close()

	resp, err := http.Get(srv.URL)
	if err != nil {
		// Expected: the request must not succeed.
		return
	}
	resp.Body.Close()
	t.Fatalf("Unexpected response: %#v", resp)
}
// testServerClient checks that the http.Client returned by Server.Client can
// fetch the URL of a server built by newTLSServer and read back the "hello"
// body, i.e. that the request completes without an error.
func testServerClient(t *testing.T, newTLSServer newServerFunc) {
	hello := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte("hello"))
	})
	srv := newTLSServer(hello)
	defer srv.Close()

	resp, err := srv.Client().Get(srv.URL)
	if err != nil {
		t.Fatal(err)
	}

	// Release the body first; the read error is examined afterwards.
	body, readErr := ioutil.ReadAll(resp.Body)
	resp.Body.Close()
	if readErr != nil {
		t.Fatal(readErr)
	}

	if got := string(body); got != "hello" {
		t.Errorf("got %q, want hello", got)
	}
}
// testServerClientTransportType checks that the Transport of the client
// returned by Server.Client has the concrete type *http.Transport.
func testServerClientTransportType(t *testing.T, newServer newServerFunc) {
	noop := func(http.ResponseWriter, *http.Request) {}
	srv := newServer(http.HandlerFunc(noop))
	defer srv.Close()

	switch tr := srv.Client().Transport.(type) {
	case *http.Transport:
		// Expected concrete type; nothing to report.
	default:
		t.Errorf("got %T, want *http.Transport", tr)
	}
}
// testTLSServerClientTransportType checks that, for a server built by
// newTLSServer, the Transport of the client returned by Server.Client has the
// concrete type *http.Transport.
func testTLSServerClientTransportType(t *testing.T, newTLSServer newServerFunc) {
	srv := newTLSServer(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}))
	defer srv.Close()

	transport := srv.Client().Transport
	if _, isHTTPTransport := transport.(*http.Transport); isHTTPTransport {
		return
	}
	t.Errorf("got %T, want *http.Transport", transport)
}
// onlyCloseListener embeds a net.Listener but replaces its Close method with
// a no-op, so Close is safe to call even when the embedded Listener is nil.
type onlyCloseListener struct {
	net.Listener
}

// Close reports success without touching the embedded Listener.
func (onlyCloseListener) Close() error {
	return nil
}
// TestServerZeroValueClose builds a Server directly, setting only its
// exported Listener and Config fields, and calls Close on it. The test
// passes as long as Close does not panic.
//
// Issue 19729: panic in Server.Close for values created directly
// without a constructor (so the unexported client field is nil).
func TestServerZeroValueClose(t *testing.T) {
	srv := new(Server)
	srv.Listener = onlyCloseListener{}
	srv.Config = new(http.Server)

	srv.Close() // tests that it doesn't panic
}
// TestTLSServerWithHTTP2 checks which protocol a request is served over.
// The handler echoes the request's Proto in the X-Proto response header.
// In the "http2" mode the server has EnableHTTP2 set and is started with
// StartTLS, and HTTP/2.0 is expected. In the "http1" mode it is started with
// Start, and HTTP/1.1 is expected.
func TestTLSServerWithHTTP2(t *testing.T) {
	modes := []struct {
		name      string
		wantProto string
	}{
		{"http1", "HTTP/1.1"},
		{"http2", "HTTP/2.0"},
	}

	for _, tt := range modes {
		t.Run(tt.name, func(t *testing.T) {
			cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("X-Proto", r.Proto)
			}))

			switch tt.name {
			case "http2":
				cst.EnableHTTP2 = true
				cst.StartTLS()
			default:
				cst.Start()
			}
			defer cst.Close()

			res, err := cst.Client().Get(cst.URL)
			if err != nil {
				t.Fatalf("Failed to make request: %v", err)
			}
			// Release the response body. Deferred calls run in LIFO order,
			// so this happens before the server is closed.
			defer res.Body.Close()

			if g, w := res.Header.Get("X-Proto"), tt.wantProto; g != w {
				t.Fatalf("X-Proto header mismatch:\n\tgot: %q\n\twant: %q", g, w)
			}
		})
	}
}