blob: e543672b1e8ab458a0b45d78e80325a88244eaa3 [file] [log] [blame]
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Implementation of Server
package httptest
import (
"bytes"
"crypto/tls"
"crypto/x509"
"flag"
"fmt"
"log"
"net"
"net/http"
"net/http/internal"
"os"
"sync"
"time"
)
// Server is an HTTP server bound to a system-chosen port on the local
// loopback interface. It exists to support end-to-end HTTP tests.
type Server struct {
	// URL is the base URL of the server, of the form http://ipaddr:port,
	// with no trailing slash. It is empty until the server is started.
	URL string

	// Listener is the network listener the server accepts connections on.
	Listener net.Listener

	// TLS is the optional TLS configuration. When set on an unstarted
	// server before StartTLS is called, its existing fields are copied
	// into the new config that StartTLS installs here.
	TLS *tls.Config

	// Config is the underlying HTTP server. It may be modified after
	// NewUnstartedServer returns and before Start or StartTLS is called.
	Config *http.Server

	// certificate is the parsed form of the TLS config certificate, if any.
	certificate *x509.Certificate

	// wg tracks outstanding work on this server; Close blocks on it
	// until everything has finished.
	wg sync.WaitGroup

	// mu guards closed and conns.
	mu sync.Mutex
	// closed records whether Close has already run its shutdown steps.
	closed bool
	// conns holds every tracked connection together with its most recent
	// state; connections in a terminal state are not kept.
	conns map[net.Conn]http.ConnState

	// client is an HTTP client configured for use with this server.
	// Close closes the idle connections of its transport.
	client *http.Client
}
// newLocalListener returns a TCP listener for a test server.
//
// If the -httptest.serve flag is non-empty, it listens on exactly that
// address. Otherwise it listens on an ephemeral IPv4 loopback port and,
// if that fails, on an ephemeral IPv6 loopback port. It panics when no
// listener can be created.
func newLocalListener() net.Listener {
	if *serve != "" {
		ln, err := net.Listen("tcp", *serve)
		if err != nil {
			panic(fmt.Sprintf("httptest: failed to listen on %v: %v", *serve, err))
		}
		return ln
	}

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err == nil {
		return ln
	}
	// IPv4 loopback is unavailable; try IPv6 loopback before giving up.
	ln, err = net.Listen("tcp6", "[::1]:0")
	if err != nil {
		panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
	}
	return ln
}
// serve is the -httptest.serve debugging flag. While investigating a
// particular HTTP server-based test, run
//
//	go test -run=BrokenTest -httptest.serve=127.0.0.1:8000
//
// to start the broken server on a fixed address so that it can be
// exercised by hand.
var serve = flag.String(
	"httptest.serve",
	"",
	"if non-empty, httptest.NewServer serves on this address and blocks",
)
// NewServer creates a Server for handler, starts it, and returns it.
// Callers should call Close on the result once they are done with it.
func NewServer(handler http.Handler) *Server {
	srv := NewUnstartedServer(handler)
	srv.Start()
	return srv
}
// NewUnstartedServer returns a new Server that has a listener and an
// HTTP configuration serving handler, but that is not yet serving.
//
// The caller may adjust the configuration and should then call Start or
// StartTLS.
//
// The caller should call Close when finished, to shut it down.
func NewUnstartedServer(handler http.Handler) *Server {
	srv := new(Server)
	srv.Listener = newLocalListener()
	srv.Config = &http.Server{Handler: handler}
	return srv
}
// Start begins serving plain HTTP on a server obtained from
// NewUnstartedServer. It panics if the server was already started.
//
// When the -httptest.serve flag is set, Start prints the server URL to
// standard error and then blocks forever.
func (s *Server) Start() {
	if s.URL != "" {
		panic("Server already started")
	}
	if s.client == nil {
		s.client = &http.Client{Transport: &http.Transport{}}
	}

	addr := s.Listener.Addr().String()
	s.URL = "http://" + addr
	s.wrap()
	s.goServe()

	if *serve == "" {
		return
	}
	// Debug mode: keep the server up so it can be poked at manually.
	fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
	select {}
}
// StartTLS begins serving HTTPS on a server obtained from
// NewUnstartedServer. It panics if the server was already started, or if
// the test key pair or the configured certificate cannot be loaded.
//
// A TLS config set by the caller is cloned rather than modified; missing
// NextProtos and Certificates are filled in with "http/1.1" and the
// built-in localhost test certificate. The server's client is given a
// transport that trusts the certificate in use.
func (s *Server) StartTLS() {
	if s.URL != "" {
		panic("Server already started")
	}
	if s.client == nil {
		s.client = &http.Client{Transport: &http.Transport{}}
	}
	defaultCert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
	if err != nil {
		panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
	}

	// Work on a private config so the caller's value is never mutated.
	if userConfig := s.TLS; userConfig == nil {
		s.TLS = new(tls.Config)
	} else {
		s.TLS = userConfig.Clone()
	}
	if s.TLS.NextProtos == nil {
		s.TLS.NextProtos = []string{"http/1.1"}
	}
	if len(s.TLS.Certificates) == 0 {
		s.TLS.Certificates = []tls.Certificate{defaultCert}
	}

	leafDER := s.TLS.Certificates[0].Certificate[0]
	s.certificate, err = x509.ParseCertificate(leafDER)
	if err != nil {
		panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
	}

	// Make the server's own client trust the certificate being served.
	roots := x509.NewCertPool()
	roots.AddCert(s.certificate)
	clientTLS := &tls.Config{RootCAs: roots}
	s.client.Transport = &http.Transport{TLSClientConfig: clientTLS}

	s.Listener = tls.NewListener(s.Listener, s.TLS)
	s.URL = "https://" + s.Listener.Addr().String()
	s.wrap()
	s.goServe()
}
// NewTLSServer creates a Server for handler, starts it with TLS enabled,
// and returns it. Callers should call Close on the result once they are
// done with it.
func NewTLSServer(handler http.Handler) *Server {
	srv := NewUnstartedServer(handler)
	srv.StartTLS()
	return srv
}
// closeIdleTransport is implemented by transports that are able to drop
// their idle connections. Close type-asserts transports against it
// before asking them to do so.
type closeIdleTransport interface {
	// CloseIdleConnections closes the connections that are currently idle.
	CloseIdleConnections()
}
// Close shuts down the server and blocks until all outstanding
// requests on this server have completed.
//
// The first call closes the listener, disables keep-alives on the HTTP
// server, and force-closes every tracked connection that is in StateNew
// or StateIdle. Every call then closes the idle connections of
// http.DefaultTransport and of the server's own client, when those
// transports support it, and waits on the server's WaitGroup.
func (s *Server) Close() {
	s.mu.Lock()
	if !s.closed {
		s.closed = true
		s.Listener.Close()
		s.Config.SetKeepAlivesEnabled(false)
		for c, st := range s.conns {
			// Force-close any idle connections (those between
			// requests) and new connections (those which connected
			// but never sent a request). StateNew connections are
			// super rare and have only been seen (in
			// previously-flaky tests) in the case of
			// socket-late-binding races from the http Client
			// dialing this server and then getting an idle
			// connection before the dial completed. There is thus
			// a connected connection in StateNew with no
			// associated Request. We only close StateIdle and
			// StateNew because they're not doing anything. It's
			// possible StateNew is about to do something in a few
			// milliseconds, but a previous CL to check again in a
			// few milliseconds wasn't liked (early versions of
			// https://golang.org/cl/15151) so now we just
			// forcefully close StateNew. The docs for Server.Close say
			// we wait for "outstanding requests", so we don't close things
			// in StateActive.
			if st == http.StateIdle || st == http.StateNew {
				s.closeConn(c)
			}
		}
		// If this server doesn't shut down in 5 seconds, tell the user why.
		// The delay has to agree with the "after 5 seconds" wording that
		// logCloseHangDebugInfo prints when the timer fires.
		t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo)
		// The deferred Stop runs when Close returns, so no report is
		// logged if shutdown finishes before the timer fires.
		defer t.Stop()
	}
	s.mu.Unlock()

	// Not part of httptest.Server's correctness, but assume most
	// users of httptest.Server will be using the standard
	// transport, so help them out and close any idle connections for them.
	if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
		t.CloseIdleConnections()
	}

	// Also close the client idle connections.
	if s.client != nil {
		if t, ok := s.client.Transport.(closeIdleTransport); ok {
			t.CloseIdleConnections()
		}
	}

	s.wg.Wait()
}
// logCloseHangDebugInfo logs one report listing every connection that is
// still tracked, with its type, address, remote address and state. Close
// schedules it on a timer so that a hanging shutdown explains itself.
// It acquires s.mu itself.
func (s *Server) logCloseHangDebugInfo() {
	s.mu.Lock()
	defer s.mu.Unlock()

	report := new(bytes.Buffer)
	report.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
	for conn, state := range s.conns {
		fmt.Fprintf(report, " %T %p %v in state %v\n", conn, conn, conn.RemoteAddr(), state)
	}
	log.Print(report.String())
}
// CloseClientConnections closes any open HTTP connections to the test
// Server. Each tracked connection is closed in its own goroutine, and the
// call waits up to 5 seconds in total for those closes to finish.
func (s *Server) CloseClientConnections() {
	s.mu.Lock()
	pending := len(s.conns)
	// Buffered for every connection so a closer never blocks on its send,
	// even after this function has given up waiting.
	closed := make(chan struct{}, pending)
	for conn := range s.conns {
		go s.closeConnChan(conn, closed)
	}
	s.mu.Unlock()

	// Wait for outstanding closes to finish.
	//
	// Out of paranoia for making a late change in Go 1.6, we
	// bound how long this can wait, since golang.org/issue/14291
	// isn't fully understood yet. At least this should only be used
	// in tests.
	deadline := time.NewTimer(5 * time.Second)
	defer deadline.Stop()
	for ; pending > 0; pending-- {
		select {
		case <-closed:
			// One more connection finished closing.
		case <-deadline.C:
			// Too slow. Give up.
			return
		}
	}
}
// Certificate reports the certificate the server presents. The result is
// nil when the server does not use TLS.
func (s *Server) Certificate() *x509.Certificate {
	cert := s.certificate
	return cert
}
// Client returns the HTTP client that is set up for talking to this
// server. The client trusts the server's TLS test certificate, and its
// idle connections are closed by Server.Close.
func (s *Server) Client() *http.Client {
	client := s.client
	return client
}
// goServe runs the HTTP server's accept loop on s.Listener in a new
// goroutine. The goroutine is registered with s.wg for as long as it
// runs; the error returned by Serve is discarded.
func (s *Server) goServe() {
	serveLoop := func() {
		defer s.wg.Done()
		s.Config.Serve(s.Listener)
	}

	s.wg.Add(1)
	go serveLoop()
}
// wrap replaces the HTTP server's ConnState hook with one that records
// the state of every connection, so the server knows which connections
// are idle. A hook that was already installed is still invoked after the
// bookkeeping, while s.mu is held.
func (s *Server) wrap() {
	prevHook := s.Config.ConnState
	s.Config.ConnState = func(conn net.Conn, state http.ConnState) {
		s.mu.Lock()
		defer s.mu.Unlock()

		switch state {
		case http.StateNew:
			// Each tracked connection holds one count on the wait group
			// until forgetConn releases it.
			s.wg.Add(1)
			if _, tracked := s.conns[conn]; tracked {
				panic("invalid state transition")
			}
			if s.conns == nil {
				s.conns = make(map[net.Conn]http.ConnState)
			}
			s.conns[conn] = state
			if s.closed {
				// Probably just a socket-late-binding dial from
				// the default transport that lost the race (and
				// thus this connection is now idle and will
				// never be used).
				s.closeConn(conn)
			}

		case http.StateActive:
			// Only connections that are still tracked get updated.
			if previous, tracked := s.conns[conn]; tracked {
				switch previous {
				case http.StateNew, http.StateIdle:
					s.conns[conn] = state
				default:
					panic("invalid state transition")
				}
			}

		case http.StateIdle:
			if previous, tracked := s.conns[conn]; tracked {
				if previous != http.StateActive {
					panic("invalid state transition")
				}
				s.conns[conn] = state
			}
			// A connection going idle after shutdown began has nothing
			// left to do, so close it right away.
			if s.closed {
				s.closeConn(conn)
			}

		case http.StateHijacked, http.StateClosed:
			// Terminal states: stop tracking the connection.
			s.forgetConn(conn)
		}

		if prevHook != nil {
			prevHook(conn, state)
		}
	}
}
// closeConn closes c without reporting completion on any channel.
// s.mu must be held.
func (s *Server) closeConn(c net.Conn) {
	s.closeConnChan(c, nil)
}
// closeConnChan closes c like closeConn does, and then, if done is
// non-nil, sends one value on it to signal that the close has finished.
// The error returned by c.Close is discarded.
func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) {
	c.Close()
	if done == nil {
		return
	}
	done <- struct{}{}
}
// forgetConn stops tracking c and releases its count on the wait group.
// It does nothing when c is not tracked, so calling it again for the same
// connection is harmless.
// s.mu must be held.
func (s *Server) forgetConn(c net.Conn) {
	if _, tracked := s.conns[c]; !tracked {
		return
	}
	delete(s.conns, c)
	s.wg.Done()
}