net/http: add Server.Close & Server.Shutdown for forced & graceful shutdown

Also updates x/net/http2 to git rev 541150 for:

   http2: add support for graceful shutdown of Server
   https://golang.org/cl/32412

   http2: make http2.Server access http1's Server via an interface check
   https://golang.org/cl/32417

Fixes #4674
Fixes #9478

Change-Id: I8021a18dee0ef2fe3946ac1776d2b10d3d429052
Reviewed-on: https://go-review.googlesource.com/32329
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
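
For anyone trying out the new API: a minimal usage sketch. The address,
signal handling, and 5-second timeout below are illustrative, not part of
this CL.

    package main

    import (
        "context"
        "log"
        "net/http"
        "os"
        "os/signal"
        "time"
    )

    func main() {
        srv := &http.Server{Addr: ":8080"} // nil Handler => DefaultServeMux

        // Serve in the background. After Shutdown or Close, Serve
        // (and thus ListenAndServe) returns ErrServerClosed.
        go func() {
            if err := srv.ListenAndServe(); err != http.ErrServerClosed {
                log.Fatalf("ListenAndServe: %v", err)
            }
        }()

        // On interrupt, give in-flight requests 5 seconds to finish.
        sigc := make(chan os.Signal, 1)
        signal.Notify(sigc, os.Interrupt)
        <-sigc

        ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
        defer cancel()
        if err := srv.Shutdown(ctx); err != nil {
            log.Printf("Shutdown: %v", err) // e.g. context.DeadlineExceeded
        }
    }
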
diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go
index 00824e7..fbed450 100644
--- a/src/net/http/export_test.go
+++ b/src/net/http/export_test.go
@@ -87,6 +87,12 @@
 	return
 }
 
+func (t *Transport) IdleConnKeyCountForTesting() int {
+	t.idleMu.Lock()
+	defer t.idleMu.Unlock()
+	return len(t.idleConn)
+}
+
 func (t *Transport) IdleConnStrsForTesting() []string {
 	var ret []string
 	t.idleMu.Lock()
diff --git a/src/net/http/h2_bundle.go b/src/net/http/h2_bundle.go
index da7c025..f8398ad 100644
--- a/src/net/http/h2_bundle.go
+++ b/src/net/http/h2_bundle.go
@@ -2982,10 +2982,6 @@
 	return http2defaultMaxStreams
 }
 
-// List of funcs for ConfigureServer to run. Both h1 and h2 are guaranteed
-// to be non-nil.
-var http2configServerFuncs []func(h1 *Server, h2 *http2Server) error
-
 // ConfigureServer adds HTTP/2 support to a net/http Server.
 //
 // The configuration conf may be nil.
@@ -3512,6 +3508,11 @@
 		sc.idleTimerCh = sc.idleTimer.C
 	}
 
+	var gracefulShutdownCh <-chan struct{}
+	if sc.hs != nil {
+		gracefulShutdownCh = http2h1ServerShutdownChan(sc.hs)
+	}
+
 	go sc.readFrames()
 
 	settingsTimer := time.NewTimer(http2firstSettingsTimeout)
@@ -3539,6 +3540,9 @@
 		case <-settingsTimer.C:
 			sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr())
 			return
+		case <-gracefulShutdownCh:
+			gracefulShutdownCh = nil
+			sc.goAwayIn(http2ErrCodeNo, 0)
 		case <-sc.shutdownTimerCh:
 			sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
 			return
@@ -3548,6 +3552,10 @@
 		case fn := <-sc.testHookCh:
 			fn(loopNum)
 		}
+
+		if sc.inGoAway && sc.curClientStreams == 0 && !sc.needToSendGoAway && !sc.writingFrame {
+			return
+		}
 	}
 }
 
@@ -3803,7 +3811,7 @@
 			sc.startFrameWrite(http2FrameWriteRequest{write: http2writeSettingsAck{}})
 			continue
 		}
-		if !sc.inGoAway {
+		if !sc.inGoAway || sc.goAwayCode == http2ErrCodeNo {
 			if wr, ok := sc.writeSched.Pop(); ok {
 				sc.startFrameWrite(wr)
 				continue
@@ -3821,14 +3829,23 @@
 
 func (sc *http2serverConn) goAway(code http2ErrCode) {
 	sc.serveG.check()
+	var forceCloseIn time.Duration
+	if code != http2ErrCodeNo {
+		forceCloseIn = 250 * time.Millisecond
+	} else {
+
+		forceCloseIn = 1 * time.Second
+	}
+	sc.goAwayIn(code, forceCloseIn)
+}
+
+func (sc *http2serverConn) goAwayIn(code http2ErrCode, forceCloseIn time.Duration) {
+	sc.serveG.check()
 	if sc.inGoAway {
 		return
 	}
-	if code != http2ErrCodeNo {
-		sc.shutDownIn(250 * time.Millisecond)
-	} else {
-
-		sc.shutDownIn(1 * time.Second)
+	if forceCloseIn != 0 {
+		sc.shutDownIn(forceCloseIn)
 	}
 	sc.inGoAway = true
 	sc.needToSendGoAway = true
@@ -5264,6 +5281,31 @@
 	"Www-Authenticate":    true,
 }
 
+// h1ServerShutdownChan returns a channel that will be closed when the
+// provided *http.Server wants to shut down.
+//
+// This is a somewhat hacky way to get at http1 innards. It works
+// when the http2 code is bundled into the net/http package in the
+// standard library. The alternatives ended up making the cmd/go tool
+// depend on http Servers. This is the lightest option for now.
+// This is tested via the TestServeShutdown* tests in net/http.
+func http2h1ServerShutdownChan(hs *Server) <-chan struct{} {
+	if fn := http2testh1ServerShutdownChan; fn != nil {
+		return fn(hs)
+	}
+	var x interface{} = hs
+	type I interface {
+		getDoneChan() <-chan struct{}
+	}
+	if hs, ok := x.(I); ok {
+		return hs.getDoneChan()
+	}
+	return nil
+}
+
+// optional test hook for h1ServerShutdownChan.
+var http2testh1ServerShutdownChan func(hs *Server) <-chan struct{}
+
 const (
 	// transportDefaultConnFlow is how many connection-level flow control
 	// tokens we give the server at start-up, past the default 64k.
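
A standalone sketch of the interface-assertion trick behind
http2h1ServerShutdownChan; the server and doneNotifier names below are
hypothetical. An interface listing an unexported method can only be
satisfied by a type from the same package, which is why the check works
once the http2 code is bundled into net/http and returns nil otherwise.

    package main

    import "fmt"

    type server struct{ done chan struct{} }

    // getDoneChan is unexported, so it is no public API, but code
    // compiled into the same package can still reach it.
    func (s *server) getDoneChan() <-chan struct{} { return s.done }

    func shutdownChan(v interface{}) <-chan struct{} {
        type doneNotifier interface {
            getDoneChan() <-chan struct{}
        }
        // Succeeds only because this file is in the same package
        // as the method; across packages it would never match.
        if dn, ok := v.(doneNotifier); ok {
            return dn.getDoneChan()
        }
        return nil // not bundled; no shutdown signal available
    }

    func main() {
        s := &server{done: make(chan struct{})}
        ch := shutdownChan(s)
        close(s.done)
        <-ch
        fmt.Println("shutdown observed")
    }
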
diff --git a/src/net/http/http_test.go b/src/net/http/http_test.go
index c6c38ff..aaae67c 100644
--- a/src/net/http/http_test.go
+++ b/src/net/http/http_test.go
@@ -12,8 +12,13 @@
 	"os/exec"
 	"reflect"
 	"testing"
+	"time"
 )
 
+func init() {
+	shutdownPollInterval = 5 * time.Millisecond
+}
+
 func TestForeachHeaderElement(t *testing.T) {
 	tests := []struct {
 		in   string
diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go
index 7834478..f855c35 100644
--- a/src/net/http/serve_test.go
+++ b/src/net/http/serve_test.go
@@ -4832,3 +4832,96 @@
 		t.Fatal("copy byte succeeded; want err")
 	}
 }
+
+func get(t *testing.T, c *Client, url string) string {
+	res, err := c.Get(url)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer res.Body.Close()
+	slurp, err := ioutil.ReadAll(res.Body)
+	if err != nil {
+		t.Fatal(err)
+	}
+	return string(slurp)
+}
+
+// Tests that calls to Server.SetKeepAlivesEnabled(false) close any
+// currently-open connections.
+func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
+	if runtime.GOOS == "nacl" {
+		t.Skip("skipping on nacl; see golang.org/issue/17695")
+	}
+	defer afterTest(t)
+	ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+		io.WriteString(w, r.RemoteAddr)
+	}))
+	defer ts.Close()
+
+	tr := &Transport{}
+	defer tr.CloseIdleConnections()
+	c := &Client{Transport: tr}
+
+	get := func() string { return get(t, c, ts.URL) }
+
+	a1, a2 := get(), get()
+	if a1 != a2 {
+		t.Fatal("expected first two requests on same connection")
+	}
+	var idle0 int
+	if !waitCondition(2*time.Second, 10*time.Millisecond, func() bool {
+		idle0 = tr.IdleConnKeyCountForTesting()
+		return idle0 == 1
+	}) {
+		t.Fatalf("idle count before SetKeepAlivesEnabled called = %v; want 1", idle0)
+	}
+
+	ts.Config.SetKeepAlivesEnabled(false)
+
+	var idle1 int
+	if !waitCondition(2*time.Second, 10*time.Millisecond, func() bool {
+		idle1 = tr.IdleConnKeyCountForTesting()
+		return idle1 == 0
+	}) {
+		t.Fatalf("idle count after SetKeepAlivesEnabled called = %v; want 0", idle1)
+	}
+
+	a3 := get()
+	if a3 == a2 {
+		t.Fatal("expected third request on new connection")
+	}
+}
+
+func TestServerShutdown_h1(t *testing.T) { testServerShutdown(t, h1Mode) }
+func TestServerShutdown_h2(t *testing.T) { testServerShutdown(t, h2Mode) }
+
+func testServerShutdown(t *testing.T, h2 bool) {
+	defer afterTest(t)
+	var doShutdown func() // set later
+	var shutdownRes = make(chan error, 1)
+	cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+		go doShutdown()
+		// Shutdown is graceful, so it should not interrupt
+		// this in-flight response. Add a tiny sleep here to
+		// increase the odds of a failure if shutdown has
+		// bugs.
+		time.Sleep(20 * time.Millisecond)
+		io.WriteString(w, r.RemoteAddr)
+	}))
+	defer cst.close()
+
+	doShutdown = func() {
+		shutdownRes <- cst.ts.Config.Shutdown(context.Background())
+	}
+	get(t, cst.c, cst.ts.URL) // calls t.Fatal on failure
+
+	if err := <-shutdownRes; err != nil {
+		t.Fatalf("Shutdown: %v", err)
+	}
+
+	res, err := cst.c.Get(cst.ts.URL)
+	if err == nil {
+		res.Body.Close()
+		t.Fatal("second request should fail; server should be shut down")
+	}
+}
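
The keep-alive test above polls via waitCondition, a helper defined
elsewhere in net/http's tests. A rough sketch of its likely shape,
assuming an import of "time" (the real helper may differ in detail):

    func waitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool {
        deadline := time.Now().Add(waitFor)
        for time.Now().Before(deadline) {
            if fn() {
                return true
            }
            time.Sleep(checkEvery)
        }
        return false
    }
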
diff --git a/src/net/http/server.go b/src/net/http/server.go
index 51a66c3..eae065f 100644
--- a/src/net/http/server.go
+++ b/src/net/http/server.go
@@ -248,6 +248,8 @@
 
 	curReq atomic.Value // of *response (which has a Request in it)
 
+	curState atomic.Value // of ConnectionState
+
 	// mu guards hijackedv
 	mu sync.Mutex
 
@@ -1586,11 +1588,30 @@
 }
 
 func (c *conn) setState(nc net.Conn, state ConnState) {
-	if hook := c.server.ConnState; hook != nil {
+	srv := c.server
+	switch state {
+	case StateNew:
+		srv.trackConn(c, true)
+	case StateHijacked, StateClosed:
+		srv.trackConn(c, false)
+	}
+	c.curState.Store(connStateInterface[state])
+	if hook := srv.ConnState; hook != nil {
 		hook(nc, state)
 	}
 }
 
+// connStateInterface is an array of the interface{} versions of
+// ConnState values, so we can use them in atomic.Values later without
+// paying the cost of shoving their integers in an interface{}.
+var connStateInterface = [...]interface{}{
+	StateNew:      StateNew,
+	StateActive:   StateActive,
+	StateIdle:     StateIdle,
+	StateHijacked: StateHijacked,
+	StateClosed:   StateClosed,
+}
+
 // badRequestError is a literal string (used by the server in HTML,
 // unescaped) to tell the user why their request was bad. It should
 // be plain text without user info or other embedded errors.
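
A self-contained sketch of the allocation dodge behind the
connStateInterface table above; the types are redeclared here only for
illustration. Storing a pre-converted interface{} into an atomic.Value
avoids boxing the integer on every setState call:

    package main

    import (
        "fmt"
        "sync/atomic"
    )

    type ConnState int

    const (
        StateNew ConnState = iota
        StateActive
        StateIdle
    )

    // Convert each value to interface{} once, up front, so Store
    // receives a ready-made interface{} with no per-call conversion.
    var boxed = [...]interface{}{
        StateNew:    StateNew,
        StateActive: StateActive,
        StateIdle:   StateIdle,
    }

    func main() {
        var v atomic.Value
        v.Store(boxed[StateIdle]) // no boxing here
        fmt.Println(v.Load().(ConnState) == StateIdle) // true
    }
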
@@ -2247,8 +2268,120 @@
 	ErrorLog *log.Logger
 
 	disableKeepAlives int32     // accessed atomically.
+	inShutdown        int32     // accessed atomically (non-zero means we're in Shutdown)
 	nextProtoOnce     sync.Once // guards setupHTTP2_* init
 	nextProtoErr      error     // result of http2.ConfigureServer if used
+
+	mu         sync.Mutex
+	listeners  map[net.Listener]struct{}
+	activeConn map[*conn]struct{}
+	doneChan   chan struct{}
+}
+
+func (s *Server) getDoneChan() <-chan struct{} {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.getDoneChanLocked()
+}
+
+func (s *Server) getDoneChanLocked() chan struct{} {
+	if s.doneChan == nil {
+		s.doneChan = make(chan struct{})
+	}
+	return s.doneChan
+}
+
+func (s *Server) closeDoneChanLocked() {
+	ch := s.getDoneChanLocked()
+	select {
+	case <-ch:
+		// Already closed. Don't close again.
+	default:
+		// Safe to close here. We're the only closer, guarded
+		// by s.mu.
+		close(ch)
+	}
+}
+
+// Close immediately closes all active net.Listeners and connections,
+// regardless of their state. For a graceful shutdown, use Shutdown.
+func (s *Server) Close() error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.closeDoneChanLocked()
+	err := s.closeListenersLocked()
+	for c := range s.activeConn {
+		c.rwc.Close()
+		delete(s.activeConn, c)
+	}
+	return err
+}
+
+// shutdownPollInterval is how often we poll for quiescence
+// during Server.Shutdown. It is lowered during tests to speed
+// them up.
+// Ideally we could find a solution that doesn't involve polling,
+// but which also doesn't have a high runtime cost (and doesn't
+// involve any contentious mutexes), but that is left as an
+// exercise for the reader.
+var shutdownPollInterval = 500 * time.Millisecond
+
+// Shutdown gracefully shuts down the server without interrupting any
+// active connections. Shutdown works by first closing all open
+// listeners, then closing all idle connections, and then waiting
+// indefinitely for connections to return to idle and then shut down.
+// If the provided context expires before the shutdown is complete,
+// then the context's error is returned.
+func (s *Server) Shutdown(ctx context.Context) error {
+	atomic.AddInt32(&s.inShutdown, 1)
+	defer atomic.AddInt32(&s.inShutdown, -1)
+
+	s.mu.Lock()
+	lnerr := s.closeListenersLocked()
+	s.closeDoneChanLocked()
+	s.mu.Unlock()
+
+	ticker := time.NewTicker(shutdownPollInterval)
+	defer ticker.Stop()
+	for {
+		if s.closeIdleConns() {
+			return lnerr
+		}
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-ticker.C:
+		}
+	}
+}
+
+// closeIdleConns closes all idle connections and reports whether the
+// server is quiescent.
+func (s *Server) closeIdleConns() bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	quiescent := true
+	for c := range s.activeConn {
+		st, ok := c.curState.Load().(ConnState)
+		if !ok || st != StateIdle {
+			quiescent = false
+			continue
+		}
+		c.rwc.Close()
+		delete(s.activeConn, c)
+	}
+	return quiescent
+}
+
+func (s *Server) closeListenersLocked() error {
+	var err error
+	for ln := range s.listeners {
+		if cerr := ln.Close(); cerr != nil && err == nil {
+			err = cerr
+		}
+		delete(s.listeners, ln)
+	}
+	return err
 }
 
 // A ConnState represents the state of a client connection to a server.
@@ -2361,6 +2494,8 @@
 	return strSliceContains(srv.TLSConfig.NextProtos, http2NextProtoTLS)
 }
 
+// ErrServerClosed is returned by the Server's Serve and
+// ListenAndServe methods after a call to Shutdown or Close.
+var ErrServerClosed = errors.New("http: Server closed")
+
 // Serve accepts incoming connections on the Listener l, creating a
 // new service goroutine for each. The service goroutines read requests and
 // then call srv.Handler to reply to them.
@@ -2370,7 +2505,8 @@
 // srv.TLSConfig is non-nil and doesn't include the string "h2" in
 // Config.NextProtos, HTTP/2 support is not enabled.
 //
-// Serve always returns a non-nil error.
+// Serve always returns a non-nil error. After Shutdown or Close, the
+// returned error is ErrServerClosed.
 func (srv *Server) Serve(l net.Listener) error {
 	defer l.Close()
 	if fn := testHookServerServe; fn != nil {
@@ -2382,12 +2518,20 @@
 		return err
 	}
 
+	srv.trackListener(l, true)
+	defer srv.trackListener(l, false)
+
 	baseCtx := context.Background() // base is always background, per Issue 16220
 	ctx := context.WithValue(baseCtx, ServerContextKey, srv)
 	ctx = context.WithValue(ctx, LocalAddrContextKey, l.Addr())
 	for {
 		rw, e := l.Accept()
 		if e != nil {
+			select {
+			case <-srv.getDoneChan():
+				return ErrServerClosed
+			default:
+			}
 			if ne, ok := e.(net.Error); ok && ne.Temporary() {
 				if tempDelay == 0 {
 					tempDelay = 5 * time.Millisecond
@@ -2410,6 +2554,37 @@
 	}
 }
 
+func (s *Server) trackListener(ln net.Listener, add bool) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.listeners == nil {
+		s.listeners = make(map[net.Listener]struct{})
+	}
+	if add {
+		// If the *Server is being reused after a previous
+		// Close or Shutdown, reset its doneChan:
+		if len(s.listeners) == 0 && len(s.activeConn) == 0 {
+			s.doneChan = nil
+		}
+		s.listeners[ln] = struct{}{}
+	} else {
+		delete(s.listeners, ln)
+	}
+}
+
+func (s *Server) trackConn(c *conn, add bool) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.activeConn == nil {
+		s.activeConn = make(map[*conn]struct{})
+	}
+	if add {
+		s.activeConn[c] = struct{}{}
+	} else {
+		delete(s.activeConn, c)
+	}
+}
+
 func (s *Server) idleTimeout() time.Duration {
 	if s.IdleTimeout != 0 {
 		return s.IdleTimeout
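
A sketch of the reuse that the doneChan reset in trackListener permits.
The loopback listeners are illustrative, and real code should also
synchronize on Serve having registered its listener before calling
Shutdown (assumes imports of "context", "net", and "net/http"):

    func reuseServer() error {
        srv := &http.Server{Handler: http.NotFoundHandler()}

        ln1, err := net.Listen("tcp", "127.0.0.1:0")
        if err != nil {
            return err
        }
        go srv.Serve(ln1) // returns ErrServerClosed once Shutdown finishes

        if err := srv.Shutdown(context.Background()); err != nil {
            return err
        }

        // No listeners or active conns remain, so the next Serve
        // resets doneChan and the same *Server is usable again.
        ln2, err := net.Listen("tcp", "127.0.0.1:0")
        if err != nil {
            return err
        }
        return srv.Serve(ln2)
    }
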
@@ -2425,7 +2600,11 @@
 }
 
 func (s *Server) doKeepAlives() bool {
-	return atomic.LoadInt32(&s.disableKeepAlives) == 0
+	return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !s.shuttingDown()
+}
+
+func (s *Server) shuttingDown() bool {
+	return atomic.LoadInt32(&s.inShutdown) != 0
 }
 
 // SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled.
@@ -2435,9 +2614,21 @@
 func (srv *Server) SetKeepAlivesEnabled(v bool) {
 	if v {
 		atomic.StoreInt32(&srv.disableKeepAlives, 0)
-	} else {
-		atomic.StoreInt32(&srv.disableKeepAlives, 1)
+		return
 	}
+	atomic.StoreInt32(&srv.disableKeepAlives, 1)
+
+	// Close idle HTTP/1 conns:
+	srv.closeIdleConns()
+
+	// Close HTTP/2 conns as soon as they become idle, but reset
+	// the chan so future conns (if the listener is still active)
+	// still work and don't get a GOAWAY immediately, before their
+	// first request:
+	srv.mu.Lock()
+	defer srv.mu.Unlock()
+	srv.closeDoneChanLocked() // closes http2 conns
+	srv.doneChan = nil
 }
 
 func (s *Server) logf(format string, args ...interface{}) {