http2: remove extra goroutine stack from awaitGracefulShutdown

This is a better fix that https://golang.org/cl/43455. Instead of
creating a separate goroutine to wait for the global shutdown channel,
we reuse the new serverMsgCh, which was added in a prior CL.

We also use the new net/http.Server.RegisterOnShutdown method to
register a shutdown callback for each http2.Server.

Updates golang/go#20302
Updates golang/go#18471

Change-Id: Icf29d5e4f65b3779d1fb4ea92924e4fb6bdadb2a
Reviewed-on: https://go-review.googlesource.com/43230
Run-TryBot: Tom Bergan <tombergan@google.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/http2/go19.go b/http2/go19.go
new file mode 100644
index 0000000..38124ba
--- /dev/null
+++ b/http2/go19.go
@@ -0,0 +1,16 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.9
+
+package http2
+
+import (
+	"net/http"
+)
+
+func configureServer19(s *http.Server, conf *Server) error {
+	s.RegisterOnShutdown(conf.state.startGracefulShutdown)
+	return nil
+}
diff --git a/http2/go19_test.go b/http2/go19_test.go
new file mode 100644
index 0000000..1675d24
--- /dev/null
+++ b/http2/go19_test.go
@@ -0,0 +1,60 @@
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build go1.9
+
+package http2
+
+import (
+	"context"
+	"net/http"
+	"reflect"
+	"testing"
+	"time"
+)
+
+func TestServerGracefulShutdown(t *testing.T) {
+	var st *serverTester
+	handlerDone := make(chan struct{})
+	st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+		defer close(handlerDone)
+		go st.ts.Config.Shutdown(context.Background())
+
+		ga := st.wantGoAway()
+		if ga.ErrCode != ErrCodeNo {
+			t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
+		}
+		if ga.LastStreamID != 1 {
+			t.Errorf("GOAWAY LastStreamID = %v; want 1", ga.LastStreamID)
+		}
+
+		w.Header().Set("x-foo", "bar")
+	})
+	defer st.Close()
+
+	st.greet()
+	st.bodylessReq1()
+
+	select {
+	case <-handlerDone:
+	case <-time.After(5 * time.Second):
+		t.Fatalf("server did not shutdown?")
+	}
+	hf := st.wantHeaders()
+	goth := st.decodeHeader(hf.HeaderBlockFragment())
+	wanth := [][2]string{
+		{":status", "200"},
+		{"x-foo", "bar"},
+		{"content-type", "text/plain; charset=utf-8"},
+		{"content-length", "0"},
+	}
+	if !reflect.DeepEqual(goth, wanth) {
+		t.Errorf("Got headers %v; want %v", goth, wanth)
+	}
+
+	n, err := st.cc.Read([]byte{0})
+	if n != 0 || err == nil {
+		t.Errorf("Read = %v, %v; want 0, non-nil", n, err)
+	}
+}
diff --git a/http2/not_go19.go b/http2/not_go19.go
new file mode 100644
index 0000000..5ae0772
--- /dev/null
+++ b/http2/not_go19.go
@@ -0,0 +1,16 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !go1.9
+
+package http2
+
+import (
+	"net/http"
+)
+
+func configureServer19(s *http.Server, conf *Server) error {
+	// not supported prior to go1.9
+	return nil
+}
diff --git a/http2/server.go b/http2/server.go
index 3175a08..7367b31 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -126,6 +126,11 @@
 	// NewWriteScheduler constructs a write scheduler for a connection.
 	// If nil, a default scheduler is chosen.
 	NewWriteScheduler func() WriteScheduler
+
+	// Internal state. This is a pointer (rather than embedded directly)
+	// so that we don't embed a Mutex in this struct, which will make the
+	// struct non-copyable, which might break some callers.
+	state *serverInternalState
 }
 
 func (s *Server) initialConnRecvWindowSize() int32 {
@@ -156,6 +161,40 @@
 	return defaultMaxStreams
 }
 
+type serverInternalState struct {
+	mu          sync.Mutex
+	activeConns map[*serverConn]struct{}
+}
+
+func (s *serverInternalState) registerConn(sc *serverConn) {
+	if s == nil {
+		return // if the Server was used without calling ConfigureServer
+	}
+	s.mu.Lock()
+	s.activeConns[sc] = struct{}{}
+	s.mu.Unlock()
+}
+
+func (s *serverInternalState) unregisterConn(sc *serverConn) {
+	if s == nil {
+		return // if the Server was used without calling ConfigureServer
+	}
+	s.mu.Lock()
+	delete(s.activeConns, sc)
+	s.mu.Unlock()
+}
+
+func (s *serverInternalState) startGracefulShutdown() {
+	if s == nil {
+		return // if the Server was used without calling ConfigureServer
+	}
+	s.mu.Lock()
+	for sc := range s.activeConns {
+		sc.startGracefulShutdown()
+	}
+	s.mu.Unlock()
+}
+
 // ConfigureServer adds HTTP/2 support to a net/http Server.
 //
 // The configuration conf may be nil.
@@ -168,9 +207,13 @@
 	if conf == nil {
 		conf = new(Server)
 	}
+	conf.state = &serverInternalState{activeConns: make(map[*serverConn]struct{})}
 	if err := configureServer18(s, conf); err != nil {
 		return err
 	}
+	if err := configureServer19(s, conf); err != nil {
+		return err
+	}
 
 	if s.TLSConfig == nil {
 		s.TLSConfig = new(tls.Config)
@@ -305,6 +348,9 @@
 		pushEnabled:                 true,
 	}
 
+	s.state.registerConn(sc)
+	defer s.state.unregisterConn(sc)
+
 	// The net/http package sets the write deadline from the
 	// http.Server.WriteTimeout during the TLS handshake, but then
 	// passes the connection off to us with the deadline already set.
@@ -445,6 +491,9 @@
 	// Owned by the writeFrameAsync goroutine:
 	headerWriteBuf bytes.Buffer
 	hpackEncoder   *hpack.Encoder
+
+	// Used by startGracefulShutdown.
+	shutdownOnce sync.Once
 }
 
 func (sc *serverConn) maxHeaderListSize() uint32 {
@@ -749,15 +798,6 @@
 		defer sc.idleTimer.Stop()
 	}
 
-	var gracefulShutdownCh chan struct{}
-	if sc.hs != nil {
-		ch := h1ServerShutdownChan(sc.hs)
-		if ch != nil {
-			gracefulShutdownCh = make(chan struct{})
-			go sc.awaitGracefulShutdown(ch, gracefulShutdownCh)
-		}
-	}
-
 	go sc.readFrames() // closed by defer sc.conn.Close above
 
 	settingsTimer := time.AfterFunc(firstSettingsTimeout, sc.onSettingsTimer)
@@ -786,14 +826,11 @@
 			}
 		case m := <-sc.bodyReadCh:
 			sc.noteBodyRead(m.st, m.n)
-		case <-gracefulShutdownCh:
-			gracefulShutdownCh = nil
-			sc.startGracefulShutdown()
 		case msg := <-sc.serveMsgCh:
 			switch v := msg.(type) {
 			case func(int):
 				v(loopNum) // for testing
-			case *timerMessage:
+			case *serverMessage:
 				switch v {
 				case settingsTimerMsg:
 					sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr())
@@ -804,6 +841,8 @@
 				case shutdownTimerMsg:
 					sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
 					return
+				case gracefulShutdownMsg:
+					sc.startGracefulShutdownInternal()
 				default:
 					panic("unknown timer")
 				}
@@ -828,13 +867,14 @@
 	}
 }
 
-type timerMessage int
+type serverMessage int
 
-// Timeout message values sent to serveMsgCh.
+// Message values sent to serveMsgCh.
 var (
-	settingsTimerMsg = new(timerMessage)
-	idleTimerMsg     = new(timerMessage)
-	shutdownTimerMsg = new(timerMessage)
+	settingsTimerMsg    = new(serverMessage)
+	idleTimerMsg        = new(serverMessage)
+	shutdownTimerMsg    = new(serverMessage)
+	gracefulShutdownMsg = new(serverMessage)
 )
 
 func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) }
@@ -1166,10 +1206,19 @@
 	sc.inFrameScheduleLoop = false
 }
 
-// startGracefulShutdown sends a GOAWAY with ErrCodeNo to tell the
-// client we're gracefully shutting down. The connection isn't closed
-// until all current streams are done.
+// startGracefulShutdown gracefully shuts down a connection. This
+// sends GOAWAY with ErrCodeNo to tell the client we're gracefully
+// shutting down. The connection isn't closed until all current
+// streams are done.
+//
+// startGracefulShutdown returns immediately; it does not wait until
+// the connection has shut down.
 func (sc *serverConn) startGracefulShutdown() {
+	sc.serveG.checkNotOn() // NOT
+	sc.shutdownOnce.Do(func() { sc.sendServeMsg(gracefulShutdownMsg) })
+}
+
+func (sc *serverConn) startGracefulShutdownInternal() {
 	sc.goAwayIn(ErrCodeNo, 0)
 }
 
@@ -1399,7 +1448,7 @@
 			sc.idleTimer.Reset(sc.srv.IdleTimeout)
 		}
 		if h1ServerKeepAlivesDisabled(sc.hs) {
-			sc.startGracefulShutdown()
+			sc.startGracefulShutdownInternal()
 		}
 	}
 	if p := st.body; p != nil {
@@ -1586,7 +1635,7 @@
 	} else {
 		sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f)
 	}
-	sc.startGracefulShutdown()
+	sc.startGracefulShutdownInternal()
 	// http://tools.ietf.org/html/rfc7540#section-6.8
 	// We should not create any new streams, which means we should disable push.
 	sc.pushEnabled = false
@@ -2653,7 +2702,7 @@
 		// A server that is unable to establish a new stream identifier can send a GOAWAY
 		// frame so that the client is forced to open a new connection for new streams.
 		if sc.maxPushPromiseID+2 >= 1<<31 {
-			sc.startGracefulShutdown()
+			sc.startGracefulShutdownInternal()
 			return 0, ErrPushLimitReached
 		}
 		sc.maxPushPromiseID += 2
@@ -2778,31 +2827,6 @@
 	"Www-Authenticate":    true,
 }
 
-// h1ServerShutdownChan returns a channel that will be closed when the
-// provided *http.Server wants to shut down.
-//
-// This is a somewhat hacky way to get at http1 innards. It works
-// when the http2 code is bundled into the net/http package in the
-// standard library. The alternatives ended up making the cmd/go tool
-// depend on http Servers. This is the lightest option for now.
-// This is tested via the TestServeShutdown* tests in net/http.
-func h1ServerShutdownChan(hs *http.Server) <-chan struct{} {
-	if fn := testh1ServerShutdownChan; fn != nil {
-		return fn(hs)
-	}
-	var x interface{} = hs
-	type I interface {
-		getDoneChan() <-chan struct{}
-	}
-	if hs, ok := x.(I); ok {
-		return hs.getDoneChan()
-	}
-	return nil
-}
-
-// optional test hook for h1ServerShutdownChan.
-var testh1ServerShutdownChan func(hs *http.Server) <-chan struct{}
-
 // h1ServerKeepAlivesDisabled reports whether hs has its keep-alives
 // disabled. See comments on h1ServerShutdownChan above for why
 // the code is written this way.
diff --git a/http2/server_test.go b/http2/server_test.go
index 5cb2490..638d2a4 100644
--- a/http2/server_test.go
+++ b/http2/server_test.go
@@ -3685,48 +3685,3 @@
 		<-done
 	}
 }
-
-func TestServerGracefulShutdown(t *testing.T) {
-	shutdownCh := make(chan struct{})
-	defer func() { testh1ServerShutdownChan = nil }()
-	testh1ServerShutdownChan = func(*http.Server) <-chan struct{} { return shutdownCh }
-
-	var st *serverTester
-	handlerDone := make(chan struct{})
-	st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
-		defer close(handlerDone)
-		close(shutdownCh)
-
-		ga := st.wantGoAway()
-		if ga.ErrCode != ErrCodeNo {
-			t.Errorf("GOAWAY error = %v; want ErrCodeNo", ga.ErrCode)
-		}
-		if ga.LastStreamID != 1 {
-			t.Errorf("GOAWAY LastStreamID = %v; want 1", ga.LastStreamID)
-		}
-
-		w.Header().Set("x-foo", "bar")
-	})
-	defer st.Close()
-
-	st.greet()
-	st.bodylessReq1()
-
-	<-handlerDone
-	hf := st.wantHeaders()
-	goth := st.decodeHeader(hf.HeaderBlockFragment())
-	wanth := [][2]string{
-		{":status", "200"},
-		{"x-foo", "bar"},
-		{"content-type", "text/plain; charset=utf-8"},
-		{"content-length", "0"},
-	}
-	if !reflect.DeepEqual(goth, wanth) {
-		t.Errorf("Got headers %v; want %v", goth, wanth)
-	}
-
-	n, err := st.cc.Read([]byte{0})
-	if n != 0 || err == nil {
-		t.Errorf("Read = %v, %v; want 0, non-nil", n, err)
-	}
-}