net/http: add Server.Close & Server.Shutdown for forced & graceful shutdown
Also updates x/net/http2 to git rev 541150 for:
http2: add support for graceful shutdown of Server
https://golang.org/cl/32412
http2: make http2.Server access http1's Server via an interface check
https://golang.org/cl/32417
Fixes #4674
Fixes #9478
Change-Id: I8021a18dee0ef2fe3946ac1776d2b10d3d429052
Reviewed-on: https://go-review.googlesource.com/32329
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go
index 00824e7..fbed450 100644
--- a/src/net/http/export_test.go
+++ b/src/net/http/export_test.go
@@ -87,6 +87,12 @@
return
}
+func (t *Transport) IdleConnKeyCountForTesting() int {
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
+ return len(t.idleConn)
+}
+
func (t *Transport) IdleConnStrsForTesting() []string {
var ret []string
t.idleMu.Lock()
diff --git a/src/net/http/h2_bundle.go b/src/net/http/h2_bundle.go
index da7c025..f8398ad 100644
--- a/src/net/http/h2_bundle.go
+++ b/src/net/http/h2_bundle.go
@@ -2982,10 +2982,6 @@
return http2defaultMaxStreams
}
-// List of funcs for ConfigureServer to run. Both h1 and h2 are guaranteed
-// to be non-nil.
-var http2configServerFuncs []func(h1 *Server, h2 *http2Server) error
-
// ConfigureServer adds HTTP/2 support to a net/http Server.
//
// The configuration conf may be nil.
@@ -3512,6 +3508,11 @@
sc.idleTimerCh = sc.idleTimer.C
}
+ var gracefulShutdownCh <-chan struct{}
+ if sc.hs != nil {
+ gracefulShutdownCh = http2h1ServerShutdownChan(sc.hs)
+ }
+
go sc.readFrames()
settingsTimer := time.NewTimer(http2firstSettingsTimeout)
@@ -3539,6 +3540,9 @@
case <-settingsTimer.C:
sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr())
return
+ case <-gracefulShutdownCh:
+ gracefulShutdownCh = nil
+ sc.goAwayIn(http2ErrCodeNo, 0)
case <-sc.shutdownTimerCh:
sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
return
@@ -3548,6 +3552,10 @@
case fn := <-sc.testHookCh:
fn(loopNum)
}
+
+ if sc.inGoAway && sc.curClientStreams == 0 && !sc.needToSendGoAway && !sc.writingFrame {
+ return
+ }
}
}
@@ -3803,7 +3811,7 @@
sc.startFrameWrite(http2FrameWriteRequest{write: http2writeSettingsAck{}})
continue
}
- if !sc.inGoAway {
+ if !sc.inGoAway || sc.goAwayCode == http2ErrCodeNo {
if wr, ok := sc.writeSched.Pop(); ok {
sc.startFrameWrite(wr)
continue
@@ -3821,14 +3829,23 @@
func (sc *http2serverConn) goAway(code http2ErrCode) {
sc.serveG.check()
+ var forceCloseIn time.Duration
+ if code != http2ErrCodeNo {
+ forceCloseIn = 250 * time.Millisecond
+ } else {
+
+ forceCloseIn = 1 * time.Second
+ }
+ sc.goAwayIn(code, forceCloseIn)
+}
+
+func (sc *http2serverConn) goAwayIn(code http2ErrCode, forceCloseIn time.Duration) {
+ sc.serveG.check()
if sc.inGoAway {
return
}
- if code != http2ErrCodeNo {
- sc.shutDownIn(250 * time.Millisecond)
- } else {
-
- sc.shutDownIn(1 * time.Second)
+ if forceCloseIn != 0 {
+ sc.shutDownIn(forceCloseIn)
}
sc.inGoAway = true
sc.needToSendGoAway = true
@@ -5264,6 +5281,31 @@
"Www-Authenticate": true,
}
+// h1ServerShutdownChan returns a channel that will be closed when the
+// provided *http.Server wants to shut down.
+//
+// This is a somewhat hacky way to get at http1 innards. It works
+// when the http2 code is bundled into the net/http package in the
+// standard library. The alternatives ended up making the cmd/go tool
+// depend on http Servers. This is the lightest option for now.
+// This is tested via the TestServeShutdown* tests in net/http.
+func http2h1ServerShutdownChan(hs *Server) <-chan struct{} {
+ if fn := http2testh1ServerShutdownChan; fn != nil {
+ return fn(hs)
+ }
+ var x interface{} = hs
+ type I interface {
+ getDoneChan() <-chan struct{}
+ }
+ if hs, ok := x.(I); ok {
+ return hs.getDoneChan()
+ }
+ return nil
+}
+
+// optional test hook for h1ServerShutdownChan.
+var http2testh1ServerShutdownChan func(hs *Server) <-chan struct{}
+
const (
// transportDefaultConnFlow is how many connection-level flow control
// tokens we give the server at start-up, past the default 64k.
diff --git a/src/net/http/http_test.go b/src/net/http/http_test.go
index c6c38ff..aaae67c 100644
--- a/src/net/http/http_test.go
+++ b/src/net/http/http_test.go
@@ -12,8 +12,13 @@
"os/exec"
"reflect"
"testing"
+ "time"
)
+func init() {
+ shutdownPollInterval = 5 * time.Millisecond
+}
+
func TestForeachHeaderElement(t *testing.T) {
tests := []struct {
in string
diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go
index 7834478..f855c35 100644
--- a/src/net/http/serve_test.go
+++ b/src/net/http/serve_test.go
@@ -4832,3 +4832,96 @@
t.Fatal("copy byte succeeded; want err")
}
}
+
+func get(t *testing.T, c *Client, url string) string {
+ res, err := c.Get(url)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return string(slurp)
+}
+
+// Tests that calls to Server.SetKeepAlivesEnabled(false) closes any
+// currently-open connections.
+func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
+ if runtime.GOOS == "nacl" {
+ t.Skip("skipping on nacl; see golang.org/issue/17695")
+ }
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.WriteString(w, r.RemoteAddr)
+ }))
+ defer ts.Close()
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ get := func() string { return get(t, c, ts.URL) }
+
+ a1, a2 := get(), get()
+ if a1 != a2 {
+ t.Fatal("expected first two requests on same connection")
+ }
+ var idle0 int
+ if !waitCondition(2*time.Second, 10*time.Millisecond, func() bool {
+ idle0 = tr.IdleConnKeyCountForTesting()
+ return idle0 == 1
+ }) {
+ t.Fatalf("idle count before SetKeepAlivesEnabled called = %v; want 1", idle0)
+ }
+
+ ts.Config.SetKeepAlivesEnabled(false)
+
+ var idle1 int
+ if !waitCondition(2*time.Second, 10*time.Millisecond, func() bool {
+ idle1 = tr.IdleConnKeyCountForTesting()
+ return idle1 == 0
+ }) {
+ t.Fatalf("idle count after SetKeepAlivesEnabled called = %v; want 0", idle1)
+ }
+
+ a3 := get()
+ if a3 == a2 {
+ t.Fatal("expected third request on new connection")
+ }
+}
+
+func TestServerShutdown_h1(t *testing.T) { testServerShutdown(t, h1Mode) }
+func TestServerShutdown_h2(t *testing.T) { testServerShutdown(t, h2Mode) }
+
+func testServerShutdown(t *testing.T, h2 bool) {
+ defer afterTest(t)
+ var doShutdown func() // set later
+ var shutdownRes = make(chan error, 1)
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ go doShutdown()
+ // Shutdown is graceful, so it should not interrupt
+ // this in-flight response. Add a tiny sleep here to
+ // increase the odds of a failure if shutdown has
+ // bugs.
+ time.Sleep(20 * time.Millisecond)
+ io.WriteString(w, r.RemoteAddr)
+ }))
+ defer cst.close()
+
+ doShutdown = func() {
+ shutdownRes <- cst.ts.Config.Shutdown(context.Background())
+ }
+ get(t, cst.c, cst.ts.URL) // calls t.Fail on failure
+
+ if err := <-shutdownRes; err != nil {
+ t.Fatalf("Shutdown: %v", err)
+ }
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err == nil {
+ res.Body.Close()
+ t.Fatal("second request should fail. server should be shut down")
+ }
+}
diff --git a/src/net/http/server.go b/src/net/http/server.go
index 51a66c3..eae065f 100644
--- a/src/net/http/server.go
+++ b/src/net/http/server.go
@@ -248,6 +248,8 @@
curReq atomic.Value // of *response (which has a Request in it)
+ curState atomic.Value // of ConnectionState
+
// mu guards hijackedv
mu sync.Mutex
@@ -1586,11 +1588,30 @@
}
func (c *conn) setState(nc net.Conn, state ConnState) {
- if hook := c.server.ConnState; hook != nil {
+ srv := c.server
+ switch state {
+ case StateNew:
+ srv.trackConn(c, true)
+ case StateHijacked, StateClosed:
+ srv.trackConn(c, false)
+ }
+ c.curState.Store(connStateInterface[state])
+ if hook := srv.ConnState; hook != nil {
hook(nc, state)
}
}
+// connStateInterface is an array of the interface{} versions of
+// ConnState values, so we can use them in atomic.Values later without
+// paying the cost of shoving their integers in an interface{}.
+var connStateInterface = [...]interface{}{
+ StateNew: StateNew,
+ StateActive: StateActive,
+ StateIdle: StateIdle,
+ StateHijacked: StateHijacked,
+ StateClosed: StateClosed,
+}
+
// badRequestError is a literal string (used by in the server in HTML,
// unescaped) to tell the user why their request was bad. It should
// be plain text without user info or other embedded errors.
@@ -2247,8 +2268,120 @@
ErrorLog *log.Logger
disableKeepAlives int32 // accessed atomically.
+ inShutdown int32 // accessed atomically (non-zero means we're in Shutdown)
nextProtoOnce sync.Once // guards setupHTTP2_* init
nextProtoErr error // result of http2.ConfigureServer if used
+
+ mu sync.Mutex
+ listeners map[net.Listener]struct{}
+ activeConn map[*conn]struct{}
+ doneChan chan struct{}
+}
+
+func (s *Server) getDoneChan() <-chan struct{} {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return s.getDoneChanLocked()
+}
+
+func (s *Server) getDoneChanLocked() chan struct{} {
+ if s.doneChan == nil {
+ s.doneChan = make(chan struct{})
+ }
+ return s.doneChan
+}
+
+func (s *Server) closeDoneChanLocked() {
+ ch := s.getDoneChanLocked()
+ select {
+ case <-ch:
+ // Already closed. Don't close again.
+ default:
+ // Safe to close here. We're the only closer, guarded
+ // by s.mu.
+ close(ch)
+ }
+}
+
+// Close immediately closes all active net.Listeners and connections,
+// regardless of their state. For a graceful shutdown, use Shutdown.
+func (s *Server) Close() error {
+ s.mu.Lock()
+ defer s.mu.Lock()
+ s.closeDoneChanLocked()
+ err := s.closeListenersLocked()
+ for c := range s.activeConn {
+ c.rwc.Close()
+ delete(s.activeConn, c)
+ }
+ return err
+}
+
+// shutdownPollInterval is how often we poll for quiescence
+// during Server.Shutdown. This is lower during tests, to
+// speed up tests.
+// Ideally we could find a solution that doesn't involve polling,
+// but which also doesn't have a high runtime cost (and doesn't
+// involve any contentious mutexes), but that is left as an
+// exercise for the reader.
+var shutdownPollInterval = 500 * time.Millisecond
+
+// Shutdown gracefully shuts down the server without interrupting any
+// active connections. Shutdown works by first closing all open
+// listeners, then closing all idle connections, and then waiting
+// indefinitely for connections to return to idle and then shut down.
+// If the provided context expires before the shutdown is complete,
+// then the context's error is returned.
+func (s *Server) Shutdown(ctx context.Context) error {
+ atomic.AddInt32(&s.inShutdown, 1)
+ defer atomic.AddInt32(&s.inShutdown, -1)
+
+ s.mu.Lock()
+ lnerr := s.closeListenersLocked()
+ s.closeDoneChanLocked()
+ s.mu.Unlock()
+
+ ticker := time.NewTicker(shutdownPollInterval)
+ defer ticker.Stop()
+ for {
+ if s.closeIdleConns() {
+ return lnerr
+ }
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-ticker.C:
+ }
+ }
+}
+
+// closeIdleConns closes all idle connections and reports whether the
+// server is quiescent.
+func (s *Server) closeIdleConns() bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ quiescent := true
+ for c := range s.activeConn {
+ st, ok := c.curState.Load().(ConnState)
+ if !ok || st != StateIdle {
+ quiescent = false
+ continue
+ }
+ c.rwc.Close()
+ delete(s.activeConn, c)
+ }
+ return quiescent
+}
+
+func (s *Server) closeListenersLocked() error {
+ var err error
+ for ln := range s.listeners {
+ if cerr := ln.Close(); cerr != nil && err == nil {
+ err = cerr
+ }
+ delete(s.listeners, ln)
+ }
+ return err
}
// A ConnState represents the state of a client connection to a server.
@@ -2361,6 +2494,8 @@
return strSliceContains(srv.TLSConfig.NextProtos, http2NextProtoTLS)
}
+var ErrServerClosed = errors.New("http: Server closed")
+
// Serve accepts incoming connections on the Listener l, creating a
// new service goroutine for each. The service goroutines read requests and
// then call srv.Handler to reply to them.
@@ -2370,7 +2505,8 @@
// srv.TLSConfig is non-nil and doesn't include the string "h2" in
// Config.NextProtos, HTTP/2 support is not enabled.
//
-// Serve always returns a non-nil error.
+// Serve always returns a non-nil error. After Shutdown or Close, the
+// returned error is ErrServerClosed.
func (srv *Server) Serve(l net.Listener) error {
defer l.Close()
if fn := testHookServerServe; fn != nil {
@@ -2382,12 +2518,20 @@
return err
}
+ srv.trackListener(l, true)
+ defer srv.trackListener(l, false)
+
baseCtx := context.Background() // base is always background, per Issue 16220
ctx := context.WithValue(baseCtx, ServerContextKey, srv)
ctx = context.WithValue(ctx, LocalAddrContextKey, l.Addr())
for {
rw, e := l.Accept()
if e != nil {
+ select {
+ case <-srv.getDoneChan():
+ return ErrServerClosed
+ default:
+ }
if ne, ok := e.(net.Error); ok && ne.Temporary() {
if tempDelay == 0 {
tempDelay = 5 * time.Millisecond
@@ -2410,6 +2554,37 @@
}
}
+func (s *Server) trackListener(ln net.Listener, add bool) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.listeners == nil {
+ s.listeners = make(map[net.Listener]struct{})
+ }
+ if add {
+ // If the *Server is being reused after a previous
+ // Close or Shutdown, reset its doneChan:
+ if len(s.listeners) == 0 && len(s.activeConn) == 0 {
+ s.doneChan = nil
+ }
+ s.listeners[ln] = struct{}{}
+ } else {
+ delete(s.listeners, ln)
+ }
+}
+
+func (s *Server) trackConn(c *conn, add bool) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.activeConn == nil {
+ s.activeConn = make(map[*conn]struct{})
+ }
+ if add {
+ s.activeConn[c] = struct{}{}
+ } else {
+ delete(s.activeConn, c)
+ }
+}
+
func (s *Server) idleTimeout() time.Duration {
if s.IdleTimeout != 0 {
return s.IdleTimeout
@@ -2425,7 +2600,11 @@
}
func (s *Server) doKeepAlives() bool {
- return atomic.LoadInt32(&s.disableKeepAlives) == 0
+ return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !s.shuttingDown()
+}
+
+func (s *Server) shuttingDown() bool {
+ return atomic.LoadInt32(&s.inShutdown) != 0
}
// SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled.
@@ -2435,9 +2614,21 @@
func (srv *Server) SetKeepAlivesEnabled(v bool) {
if v {
atomic.StoreInt32(&srv.disableKeepAlives, 0)
- } else {
- atomic.StoreInt32(&srv.disableKeepAlives, 1)
+ return
}
+ atomic.StoreInt32(&srv.disableKeepAlives, 1)
+
+ // Close idle HTTP/1 conns:
+ srv.closeIdleConns()
+
+ // Close HTTP/2 conns, as soon as they become idle, but reset
+ // the chan so future conns (if the listener is still active)
+ // still work and don't get a GOAWAY immediately, before their
+ // first request:
+ srv.mu.Lock()
+ defer srv.mu.Unlock()
+ srv.closeDoneChanLocked() // closes http2 conns
+ srv.doneChan = nil
}
func (s *Server) logf(format string, args ...interface{}) {